Skip to content

Commit b02c401

Browse files
authored
docs: add code snippets for imported tensorflow model (#679)
1 parent 0b8b827 commit b02c401

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (t
4+
# you may not use this file except in compliance wi
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in
10+
# distributed under the License is distributed on a
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, eit
12+
# See the License for the specific language governi
13+
# limitations under the License.
14+
15+
16+
def test_imported_tensorflow_model() -> None:
17+
# Determine project id, in this case prefer the one set in the environment
18+
# variable GOOGLE_CLOUD_PROJECT (if any)
19+
import os
20+
21+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT", "bigframes-dev")
22+
23+
# [START bigquery_dataframes_imported_tensorflow_tutorial_import_tensorflow_models]
24+
import bigframes
25+
from bigframes.ml.imported import TensorFlowModel
26+
27+
bigframes.options.bigquery.project = PROJECT_ID
28+
# You can change the location to one of the valid locations: https://cloud.google.com/bigquery/docs/locations#supported_locations
29+
bigframes.options.bigquery.location = "US"
30+
31+
imported_tensorflow_model = TensorFlowModel(
32+
model_path="gs://cloud-training-demos/txtclass/export/exporter/1549825580/*"
33+
)
34+
# [END bigquery_dataframes_imported_tensorflow_tutorial_import_tensorflow_models]
35+
assert imported_tensorflow_model is not None
36+
37+
# [START bigquery_dataframes_imported_tensorflow_tutorial_make_predictions]
38+
import bigframes.pandas as bpd
39+
40+
df = bpd.read_gbq("bigquery-public-data.hacker_news.full")
41+
predictions = imported_tensorflow_model.predict(df)
42+
predictions.head(5)
43+
# [END bigquery_dataframes_imported_tensorflow_tutorial_make_predictions]

0 commit comments

Comments
 (0)