@@ -55,25 +55,28 @@ def test_create_text_generator_model_default_session(
55
55
):
56
56
import bigframes .pandas as bpd
57
57
58
- bpd .close_session ()
59
- bpd .options .bigquery .bq_connection = bq_connection
60
- bpd .options .bigquery .location = "us"
61
-
62
- model = llm .PaLM2TextGenerator ()
63
- assert model is not None
64
- assert model ._bqml_model is not None
65
- assert (
66
- model .connection_name .casefold ()
67
- == f"{ bigquery_client .project } .us.bigframes-rf-conn"
68
- )
69
-
70
- llm_text_df = bpd .read_pandas (llm_text_pandas_df )
71
-
72
- df = model .predict (llm_text_df ).to_pandas ()
73
- assert df .shape == (3 , 4 )
74
- assert "ml_generate_text_llm_result" in df .columns
75
- series = df ["ml_generate_text_llm_result" ]
76
- assert all (series .str .len () > 20 )
58
+ # Note: This starts a thread-local session.
59
+ with bpd .option_context (
60
+ "bigquery.bq_connection" ,
61
+ bq_connection ,
62
+ "bigquery.location" ,
63
+ "US" ,
64
+ ):
65
+ model = llm .PaLM2TextGenerator ()
66
+ assert model is not None
67
+ assert model ._bqml_model is not None
68
+ assert (
69
+ model .connection_name .casefold ()
70
+ == f"{ bigquery_client .project } .us.bigframes-rf-conn"
71
+ )
72
+
73
+ llm_text_df = bpd .read_pandas (llm_text_pandas_df )
74
+
75
+ df = model .predict (llm_text_df ).to_pandas ()
76
+ assert df .shape == (3 , 4 )
77
+ assert "ml_generate_text_llm_result" in df .columns
78
+ series = df ["ml_generate_text_llm_result" ]
79
+ assert all (series .str .len () > 20 )
77
80
78
81
79
82
@pytest .mark .flaky (retries = 2 )
@@ -82,25 +85,28 @@ def test_create_text_generator_32k_model_default_session(
82
85
):
83
86
import bigframes .pandas as bpd
84
87
85
- bpd .close_session ()
86
- bpd .options .bigquery .bq_connection = bq_connection
87
- bpd .options .bigquery .location = "us"
88
-
89
- model = llm .PaLM2TextGenerator (model_name = "text-bison-32k" )
90
- assert model is not None
91
- assert model ._bqml_model is not None
92
- assert (
93
- model .connection_name .casefold ()
94
- == f"{ bigquery_client .project } .us.bigframes-rf-conn"
95
- )
96
-
97
- llm_text_df = bpd .read_pandas (llm_text_pandas_df )
98
-
99
- df = model .predict (llm_text_df ).to_pandas ()
100
- assert df .shape == (3 , 4 )
101
- assert "ml_generate_text_llm_result" in df .columns
102
- series = df ["ml_generate_text_llm_result" ]
103
- assert all (series .str .len () > 20 )
88
+ # Note: This starts a thread-local session.
89
+ with bpd .option_context (
90
+ "bigquery.bq_connection" ,
91
+ bq_connection ,
92
+ "bigquery.location" ,
93
+ "US" ,
94
+ ):
95
+ model = llm .PaLM2TextGenerator (model_name = "text-bison-32k" )
96
+ assert model is not None
97
+ assert model ._bqml_model is not None
98
+ assert (
99
+ model .connection_name .casefold ()
100
+ == f"{ bigquery_client .project } .us.bigframes-rf-conn"
101
+ )
102
+
103
+ llm_text_df = bpd .read_pandas (llm_text_pandas_df )
104
+
105
+ df = model .predict (llm_text_df ).to_pandas ()
106
+ assert df .shape == (3 , 4 )
107
+ assert "ml_generate_text_llm_result" in df .columns
108
+ series = df ["ml_generate_text_llm_result" ]
109
+ assert all (series .str .len () > 20 )
104
110
105
111
106
112
@pytest .mark .flaky (retries = 2 )
@@ -232,27 +238,33 @@ def test_create_embedding_generator_multilingual_model(
232
238
def test_create_text_embedding_generator_model_defaults (bq_connection ):
233
239
import bigframes .pandas as bpd
234
240
235
- bpd .close_session ()
236
- bpd .options .bigquery .bq_connection = bq_connection
237
- bpd .options .bigquery .location = "us"
238
-
239
- model = llm .PaLM2TextEmbeddingGenerator ()
240
- assert model is not None
241
- assert model ._bqml_model is not None
241
+ # Note: This starts a thread-local session.
242
+ with bpd .option_context (
243
+ "bigquery.bq_connection" ,
244
+ bq_connection ,
245
+ "bigquery.location" ,
246
+ "US" ,
247
+ ):
248
+ model = llm .PaLM2TextEmbeddingGenerator ()
249
+ assert model is not None
250
+ assert model ._bqml_model is not None
242
251
243
252
244
253
def test_create_text_embedding_generator_multilingual_model_defaults (bq_connection ):
245
254
import bigframes .pandas as bpd
246
255
247
- bpd .close_session ()
248
- bpd .options .bigquery .bq_connection = bq_connection
249
- bpd .options .bigquery .location = "us"
250
-
251
- model = llm .PaLM2TextEmbeddingGenerator (
252
- model_name = "textembedding-gecko-multilingual"
253
- )
254
- assert model is not None
255
- assert model ._bqml_model is not None
256
+ # Note: This starts a thread-local session.
257
+ with bpd .option_context (
258
+ "bigquery.bq_connection" ,
259
+ bq_connection ,
260
+ "bigquery.location" ,
261
+ "US" ,
262
+ ):
263
+ model = llm .PaLM2TextEmbeddingGenerator (
264
+ model_name = "textembedding-gecko-multilingual"
265
+ )
266
+ assert model is not None
267
+ assert model ._bqml_model is not None
256
268
257
269
258
270
@pytest .mark .flaky (retries = 2 )
0 commit comments