Skip to content

Commit 651fd7d

Browse files
authored
feat: bigframes.options and bigframes.option_context now uses thread-local variables to prevent context managers in separate threads from affecting each other (#652)
* feat: `bigframes.options` and `bigframes.option_context` now uses thread-local variables to prevent context managers in separate threads from affecting each other In our tests, this allows us to actually test things like `bf.option_context("display.repr_mode", "deferred"):` without always having some other test change the display mode and break the test. Fixes internal issue 308657813 * catch close errors on thread-local session too * use presence of _local.bigquery_options to indicate thread locality feat: always do a query dry run when `option.repr_mode == "deferred"` (#652)
1 parent 2715d2b commit 651fd7d

File tree

10 files changed

+275
-119
lines changed

10 files changed

+275
-119
lines changed

bigframes/_config/__init__.py

+62-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
DataFrames from this package.
1818
"""
1919

20+
import copy
21+
import threading
22+
2023
import bigframes_vendored.pandas._config.config as pandas_config
2124

2225
import bigframes._config.bigquery_options as bigquery_options
@@ -29,44 +32,91 @@ class Options:
2932
"""Global options affecting BigQuery DataFrames behavior."""
3033

3134
def __init__(self):
35+
self._local = threading.local()
36+
37+
# Initialize these in the property getters to make sure we do have a
38+
# separate instance per thread.
39+
self._local.bigquery_options = None
40+
self._local.display_options = None
41+
self._local.sampling_options = None
42+
self._local.compute_options = None
43+
44+
# BigQuery options are special because they can only be set once per
45+
# session, so we need an indicator as to whether we are using the
46+
# thread-local session or the global session.
3247
self._bigquery_options = bigquery_options.BigQueryOptions()
33-
self._display_options = display_options.DisplayOptions()
34-
self._sampling_options = sampling_options.SamplingOptions()
35-
self._compute_options = compute_options.ComputeOptions()
48+
49+
def _init_bigquery_thread_local(self):
50+
"""Initialize thread-local options, based on current global options."""
51+
52+
# Already thread-local, so don't reset any options that have been set
53+
# already. No locks needed since this only modifies thread-local
54+
# variables.
55+
if self._local.bigquery_options is not None:
56+
return
57+
58+
self._local.bigquery_options = copy.deepcopy(self._bigquery_options)
59+
self._local.bigquery_options._session_started = False
3660

3761
@property
3862
def bigquery(self) -> bigquery_options.BigQueryOptions:
3963
"""Options to use with the BigQuery engine."""
64+
if self._local.bigquery_options is not None:
65+
# The only way we can get here is if someone called
66+
# _init_bigquery_thread_local.
67+
return self._local.bigquery_options
68+
4069
return self._bigquery_options
4170

4271
@property
4372
def display(self) -> display_options.DisplayOptions:
4473
"""Options controlling object representation."""
45-
return self._display_options
74+
if self._local.display_options is None:
75+
self._local.display_options = display_options.DisplayOptions()
76+
77+
return self._local.display_options
4678

4779
@property
4880
def sampling(self) -> sampling_options.SamplingOptions:
4981
"""Options controlling downsampling when downloading data
50-
to memory. The data will be downloaded into memory explicitly
82+
to memory.
83+
84+
The data can be downloaded into memory explicitly
5185
(e.g., to_pandas, to_numpy, values) or implicitly (e.g.,
5286
matplotlib plotting). This option can be overriden by
53-
parameters in specific functions."""
54-
return self._sampling_options
87+
parameters in specific functions.
88+
"""
89+
if self._local.sampling_options is None:
90+
self._local.sampling_options = sampling_options.SamplingOptions()
91+
92+
return self._local.sampling_options
5593

5694
@property
5795
def compute(self) -> compute_options.ComputeOptions:
58-
"""Options controlling object computation."""
59-
return self._compute_options
96+
"""Thread-local options controlling object computation."""
97+
if self._local.compute_options is None:
98+
self._local.compute_options = compute_options.ComputeOptions()
99+
100+
return self._local.compute_options
101+
102+
@property
103+
def is_bigquery_thread_local(self) -> bool:
104+
"""Indicator that we're using a thread-local session.
105+
106+
A thread-local session can be started by using
107+
`with bigframes.option_context("bigquery.some_option", "some-value"):`.
108+
"""
109+
return self._local.bigquery_options is not None
60110

61111

62112
options = Options()
63113
"""Global options for default session."""
64114

115+
option_context = pandas_config.option_context
116+
65117

66118
__all__ = (
67119
"Options",
68120
"options",
121+
"option_context",
69122
)
70-
71-
72-
option_context = pandas_config.option_context

bigframes/core/global_session.py

+46-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,24 @@
2626

2727
_global_session: Optional[bigframes.session.Session] = None
2828
_global_session_lock = threading.Lock()
29+
_global_session_state = threading.local()
30+
_global_session_state.thread_local_session = None
31+
32+
33+
def _try_close_session(session):
34+
"""Try to close the session and warn if couldn't."""
35+
try:
36+
session.close()
37+
except google.auth.exceptions.RefreshError as e:
38+
session_id = session.session_id
39+
location = session._location
40+
project_id = session._project
41+
warnings.warn(
42+
f"Session cleanup failed for session with id: {session_id}, "
43+
f"location: {location}, project: {project_id}",
44+
category=bigframes.exceptions.CleanupFailedWarning,
45+
)
46+
traceback.print_tb(e.__traceback__)
2947

3048

3149
def close_session() -> None:
@@ -37,24 +55,30 @@ def close_session() -> None:
3755
Returns:
3856
None
3957
"""
40-
global _global_session
58+
global _global_session, _global_session_lock, _global_session_state
59+
60+
if bigframes._config.options.is_bigquery_thread_local:
61+
if _global_session_state.thread_local_session is not None:
62+
_try_close_session(_global_session_state.thread_local_session)
63+
_global_session_state.thread_local_session = None
64+
65+
# Currently using thread-local options, so no global lock needed.
66+
# Don't reset options.bigquery, as that's the responsibility
67+
# of the context manager that started it in the first place. The user
68+
# might have explicitly closed the session in the context manager and
69+
# the thread-locality property needs to be retained.
70+
bigframes._config.options.bigquery._session_started = False
71+
72+
# Don't close the non-thread-local session.
73+
return
4174

4275
with _global_session_lock:
4376
if _global_session is not None:
44-
try:
45-
_global_session.close()
46-
except google.auth.exceptions.RefreshError as e:
47-
session_id = _global_session.session_id
48-
location = _global_session._location
49-
project_id = _global_session._project
50-
warnings.warn(
51-
f"Session cleanup failed for session with id: {session_id}, "
52-
f"location: {location}, project: {project_id}",
53-
category=bigframes.exceptions.CleanupFailedWarning,
54-
)
55-
traceback.print_tb(e.__traceback__)
77+
_try_close_session(_global_session)
5678
_global_session = None
5779

80+
# This should be global, not thread-local because of the if clause
81+
# above.
5882
bigframes._config.options.bigquery._session_started = False
5983

6084

@@ -63,7 +87,15 @@ def get_global_session():
6387
6488
Creates the global session if it does not exist.
6589
"""
66-
global _global_session, _global_session_lock
90+
global _global_session, _global_session_lock, _global_session_state
91+
92+
if bigframes._config.options.is_bigquery_thread_local:
93+
if _global_session_state.thread_local_session is None:
94+
_global_session_state.thread_local_session = bigframes.session.connect(
95+
bigframes._config.options.bigquery
96+
)
97+
98+
return _global_session_state.thread_local_session
6799

68100
with _global_session_lock:
69101
if _global_session is None:

bigframes/core/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __repr__(self) -> str:
239239
opts = bigframes.options.display
240240
max_results = opts.max_rows
241241
if opts.repr_mode == "deferred":
242-
return formatter.repr_query_job(self.query_job)
242+
return formatter.repr_query_job(self._block._compute_dry_run())
243243

244244
pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results)
245245
self._query_job = query_job

bigframes/dataframe.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def __repr__(self) -> str:
595595
opts = bigframes.options.display
596596
max_results = opts.max_rows
597597
if opts.repr_mode == "deferred":
598-
return formatter.repr_query_job(self.query_job)
598+
return formatter.repr_query_job(self._compute_dry_run())
599599

600600
self._cached()
601601
# TODO(swast): pass max_columns and get the true column count back. Maybe
@@ -632,9 +632,9 @@ def _repr_html_(self) -> str:
632632
many notebooks are not configured for large tables.
633633
"""
634634
opts = bigframes.options.display
635-
max_results = bigframes.options.display.max_rows
635+
max_results = opts.max_rows
636636
if opts.repr_mode == "deferred":
637-
return formatter.repr_query_job_html(self.query_job)
637+
return formatter.repr_query_job(self._compute_dry_run())
638638

639639
self._cached()
640640
# TODO(swast): pass max_columns and get the true column count back. Maybe

bigframes/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def __repr__(self) -> str:
282282
opts = bigframes.options.display
283283
max_results = opts.max_rows
284284
if opts.repr_mode == "deferred":
285-
return formatter.repr_query_job(self.query_job)
285+
return formatter.repr_query_job(self._compute_dry_run())
286286

287287
self._cached()
288288
pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results)

tests/system/small/ml/test_llm.py

+66-54
Original file line numberDiff line numberDiff line change
@@ -55,25 +55,28 @@ def test_create_text_generator_model_default_session(
5555
):
5656
import bigframes.pandas as bpd
5757

58-
bpd.close_session()
59-
bpd.options.bigquery.bq_connection = bq_connection
60-
bpd.options.bigquery.location = "us"
61-
62-
model = llm.PaLM2TextGenerator()
63-
assert model is not None
64-
assert model._bqml_model is not None
65-
assert (
66-
model.connection_name.casefold()
67-
== f"{bigquery_client.project}.us.bigframes-rf-conn"
68-
)
69-
70-
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
71-
72-
df = model.predict(llm_text_df).to_pandas()
73-
assert df.shape == (3, 4)
74-
assert "ml_generate_text_llm_result" in df.columns
75-
series = df["ml_generate_text_llm_result"]
76-
assert all(series.str.len() > 20)
58+
# Note: This starts a thread-local session.
59+
with bpd.option_context(
60+
"bigquery.bq_connection",
61+
bq_connection,
62+
"bigquery.location",
63+
"US",
64+
):
65+
model = llm.PaLM2TextGenerator()
66+
assert model is not None
67+
assert model._bqml_model is not None
68+
assert (
69+
model.connection_name.casefold()
70+
== f"{bigquery_client.project}.us.bigframes-rf-conn"
71+
)
72+
73+
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
74+
75+
df = model.predict(llm_text_df).to_pandas()
76+
assert df.shape == (3, 4)
77+
assert "ml_generate_text_llm_result" in df.columns
78+
series = df["ml_generate_text_llm_result"]
79+
assert all(series.str.len() > 20)
7780

7881

7982
@pytest.mark.flaky(retries=2)
@@ -82,25 +85,28 @@ def test_create_text_generator_32k_model_default_session(
8285
):
8386
import bigframes.pandas as bpd
8487

85-
bpd.close_session()
86-
bpd.options.bigquery.bq_connection = bq_connection
87-
bpd.options.bigquery.location = "us"
88-
89-
model = llm.PaLM2TextGenerator(model_name="text-bison-32k")
90-
assert model is not None
91-
assert model._bqml_model is not None
92-
assert (
93-
model.connection_name.casefold()
94-
== f"{bigquery_client.project}.us.bigframes-rf-conn"
95-
)
96-
97-
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
98-
99-
df = model.predict(llm_text_df).to_pandas()
100-
assert df.shape == (3, 4)
101-
assert "ml_generate_text_llm_result" in df.columns
102-
series = df["ml_generate_text_llm_result"]
103-
assert all(series.str.len() > 20)
88+
# Note: This starts a thread-local session.
89+
with bpd.option_context(
90+
"bigquery.bq_connection",
91+
bq_connection,
92+
"bigquery.location",
93+
"US",
94+
):
95+
model = llm.PaLM2TextGenerator(model_name="text-bison-32k")
96+
assert model is not None
97+
assert model._bqml_model is not None
98+
assert (
99+
model.connection_name.casefold()
100+
== f"{bigquery_client.project}.us.bigframes-rf-conn"
101+
)
102+
103+
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
104+
105+
df = model.predict(llm_text_df).to_pandas()
106+
assert df.shape == (3, 4)
107+
assert "ml_generate_text_llm_result" in df.columns
108+
series = df["ml_generate_text_llm_result"]
109+
assert all(series.str.len() > 20)
104110

105111

106112
@pytest.mark.flaky(retries=2)
@@ -232,27 +238,33 @@ def test_create_embedding_generator_multilingual_model(
232238
def test_create_text_embedding_generator_model_defaults(bq_connection):
233239
import bigframes.pandas as bpd
234240

235-
bpd.close_session()
236-
bpd.options.bigquery.bq_connection = bq_connection
237-
bpd.options.bigquery.location = "us"
238-
239-
model = llm.PaLM2TextEmbeddingGenerator()
240-
assert model is not None
241-
assert model._bqml_model is not None
241+
# Note: This starts a thread-local session.
242+
with bpd.option_context(
243+
"bigquery.bq_connection",
244+
bq_connection,
245+
"bigquery.location",
246+
"US",
247+
):
248+
model = llm.PaLM2TextEmbeddingGenerator()
249+
assert model is not None
250+
assert model._bqml_model is not None
242251

243252

244253
def test_create_text_embedding_generator_multilingual_model_defaults(bq_connection):
245254
import bigframes.pandas as bpd
246255

247-
bpd.close_session()
248-
bpd.options.bigquery.bq_connection = bq_connection
249-
bpd.options.bigquery.location = "us"
250-
251-
model = llm.PaLM2TextEmbeddingGenerator(
252-
model_name="textembedding-gecko-multilingual"
253-
)
254-
assert model is not None
255-
assert model._bqml_model is not None
256+
# Note: This starts a thread-local session.
257+
with bpd.option_context(
258+
"bigquery.bq_connection",
259+
bq_connection,
260+
"bigquery.location",
261+
"US",
262+
):
263+
model = llm.PaLM2TextEmbeddingGenerator(
264+
model_name="textembedding-gecko-multilingual"
265+
)
266+
assert model is not None
267+
assert model._bqml_model is not None
256268

257269

258270
@pytest.mark.flaky(retries=2)

0 commit comments

Comments
 (0)