Commit 3acc494

refactor: cache table metadata alongside snapshot time (#636)
This makes the cached `primary_keys` more likely to be correct, for example when the user ran ALTER TABLE after we originally cached the snapshot time.
1 parent 96c150a commit 3acc494
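
For context, the core of the change is the shape of the session's snapshot cache: it previously mapped a table reference to just a timestamp, and it now stores the `google.cloud.bigquery.Table` metadata fetched at the same moment. A minimal sketch of the before/after types (variable names here are illustrative, not the exact code):

import datetime
from typing import Dict, Tuple

import google.cloud.bigquery as bigquery

# Before this commit: only the snapshot timestamp was cached, so table
# metadata (including primary-key constraints) was re-fetched on each read
# and could be newer than the cached snapshot time.
df_snapshot_before: Dict[bigquery.TableReference, datetime.datetime] = {}

# After this commit: the Table metadata fetched when the snapshot time was
# taken is cached alongside it, so both stay consistent even if the table
# is altered later.
df_snapshot_after: Dict[
    bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]
] = {}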

3 files changed, 75 insertions(+), 38 deletions(-):
bigframes/session/__init__.py
bigframes/session/_io/bigquery.py
tests/unit/session/test_session.py


bigframes/session/__init__.py (+16, -36)

@@ -231,7 +231,9 @@ def __init__(
         # Now that we're starting the session, don't allow the options to be
         # changed.
         context._session_started = True
-        self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}
+        self._df_snapshot: Dict[
+            bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]
+        ] = {}
 
     @property
     def bqclient(self):
@@ -698,16 +700,25 @@ def _get_snapshot_sql_and_primary_key(
         column(s), then return those too so that ordering generation can be
         avoided.
         """
-        # If there are primary keys defined, the query engine assumes these
-        # columns are unique, even if the constraint is not enforced. We make
-        # the same assumption and use these columns as the total ordering keys.
-        table = self.bqclient.get_table(table_ref)
+        (
+            snapshot_timestamp,
+            table,
+        ) = bigframes_io.get_snapshot_datetime_and_table_metadata(
+            self.bqclient,
+            table_ref=table_ref,
+            api_name=api_name,
+            cache=self._df_snapshot,
+            use_cache=use_cache,
+        )
 
         if table.location.casefold() != self._location.casefold():
            raise ValueError(
                f"Current session is in {self._location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
            )
 
+        # If there are primary keys defined, the query engine assumes these
+        # columns are unique, even if the constraint is not enforced. We make
+        # the same assumption and use these columns as the total ordering keys.
         primary_keys = None
         if (
             (table_constraints := getattr(table, "table_constraints", None)) is not None
@@ -718,37 +729,6 @@ def _get_snapshot_sql_and_primary_key(
         ):
             primary_keys = columns
 
-        job_config = bigquery.QueryJobConfig()
-        job_config.labels["bigframes-api"] = api_name
-        if use_cache and table_ref in self._df_snapshot.keys():
-            snapshot_timestamp = self._df_snapshot[table_ref]
-
-            # Cache hit could be unexpected. See internal issue 329545805.
-            # Raise a warning with more information about how to avoid the
-            # problems with the cache.
-            warnings.warn(
-                f"Reading cached table from {snapshot_timestamp} to avoid "
-                "incompatibilies with previous reads of this table. To read "
-                "the latest version, set `use_cache=False` or close the "
-                "current session with Session.close() or "
-                "bigframes.pandas.close_session().",
-                # There are many layers before we get to (possibly) the user's code:
-                # pandas.read_gbq_table
-                # -> with_default_session
-                # -> Session.read_gbq_table
-                # -> _read_gbq_table
-                # -> _get_snapshot_sql_and_primary_key
-                stacklevel=6,
-            )
-        else:
-            snapshot_timestamp = list(
-                self.bqclient.query(
-                    "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
-                    job_config=job_config,
-                ).result()
-            )[0][0]
-            self._df_snapshot[table_ref] = snapshot_timestamp
-
         try:
             table_expression = self.ibis_client.sql(
                 bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
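
The primary-key lookup that follows this hunk is cut off mid-condition above. As a rough, illustrative sketch (not the exact lines from the file), reading primary keys out of the cached table metadata looks roughly like this, assuming a recent google-cloud-bigquery client that exposes `table_constraints`:

from typing import List, Optional

import google.cloud.bigquery as bigquery


def primary_keys_from_table(table: bigquery.Table) -> Optional[List[str]]:
    # table_constraints is only populated when the table defines constraints;
    # older client versions may not expose the attribute at all.
    constraints = getattr(table, "table_constraints", None)
    if constraints is not None and constraints.primary_key is not None:
        columns = constraints.primary_key.columns
        if columns:
            return list(columns)
    return None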

bigframes/session/_io/bigquery.py (+54)

@@ -23,6 +23,7 @@
 import types
 from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
 import uuid
+import warnings
 
 import google.api_core.exceptions
 import google.cloud.bigquery as bigquery
@@ -121,6 +122,59 @@ def table_ref_to_sql(table: bigquery.TableReference) -> str:
     return f"`{table.project}`.`{table.dataset_id}`.`{table.table_id}`"
 
 
+def get_snapshot_datetime_and_table_metadata(
+    bqclient: bigquery.Client,
+    table_ref: bigquery.TableReference,
+    *,
+    api_name: str,
+    cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]],
+    use_cache: bool = True,
+) -> Tuple[datetime.datetime, bigquery.Table]:
+    cached_table = cache.get(table_ref)
+    if use_cache and cached_table is not None:
+        snapshot_timestamp, _ = cached_table
+
+        # Cache hit could be unexpected. See internal issue 329545805.
+        # Raise a warning with more information about how to avoid the
+        # problems with the cache.
+        warnings.warn(
+            f"Reading cached table from {snapshot_timestamp} to avoid "
+            "incompatibilies with previous reads of this table. To read "
+            "the latest version, set `use_cache=False` or close the "
+            "current session with Session.close() or "
+            "bigframes.pandas.close_session().",
+            # There are many layers before we get to (possibly) the user's code:
+            # pandas.read_gbq_table
+            # -> with_default_session
+            # -> Session.read_gbq_table
+            # -> _read_gbq_table
+            # -> _get_snapshot_sql_and_primary_key
+            # -> get_snapshot_datetime_and_table_metadata
+            stacklevel=7,
+        )
+        return cached_table
+
+    # TODO(swast): It's possible that the table metadata is changed between now
+    # and when we run the CURRENT_TIMESTAMP() query to see when we can time
+    # travel to. Find a way to fetch the table metadata and BQ's current time
+    # atomically.
+    table = bqclient.get_table(table_ref)
+
+    # TODO(b/336521938): Refactor to make sure we set the "bigframes-api"
+    # whereever we execute a query.
+    job_config = bigquery.QueryJobConfig()
+    job_config.labels["bigframes-api"] = api_name
+    snapshot_timestamp = list(
+        bqclient.query(
+            "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
+            job_config=job_config,
+        ).result()
+    )[0][0]
+    cached_table = (snapshot_timestamp, table)
+    cache[table_ref] = cached_table
+    return cached_table
+
+
 def create_snapshot_sql(
     table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
 ) -> str:
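
A rough usage sketch of the new helper, mirroring how the session calls it. The client, table reference, and `api_name` value below are placeholders, and the import alias matches the `bigframes_io` alias used in the session diff above:

import datetime
from typing import Dict, Tuple

import google.cloud.bigquery as bigquery

import bigframes.session._io.bigquery as bigframes_io

bqclient = bigquery.Client()
table_ref = bigquery.TableReference(
    bigquery.DatasetReference("my-project", "my_dataset"), "my_table"
)
cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]] = {}

# First call: fetches the table metadata, asks BigQuery for CURRENT_TIMESTAMP(),
# stores the (timestamp, table) pair in the cache, and returns it.
snapshot_timestamp, table = bigframes_io.get_snapshot_datetime_and_table_metadata(
    bqclient,
    table_ref=table_ref,
    api_name="read_gbq_table",  # illustrative label value
    cache=cache,
    use_cache=True,
)

# Second call with use_cache=True: returns the cached pair and emits the
# UserWarning shown in the function above.
cached_timestamp, cached_table = bigframes_io.get_snapshot_datetime_and_table_metadata(
    bqclient,
    table_ref=table_ref,
    api_name="read_gbq_table",
    cache=cache,
    use_cache=True,
)
assert cached_timestamp == snapshot_timestamp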

tests/unit/session/test_session.py (+5, -2)

@@ -42,8 +42,11 @@ def test_read_gbq_cached_table():
         google.cloud.bigquery.DatasetReference("my-project", "my_dataset"),
         "my_table",
     )
-    session._df_snapshot[table_ref] = datetime.datetime(
-        1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc
+    table = google.cloud.bigquery.Table(table_ref)
+    table._properties["location"] = session._location
+    session._df_snapshot[table_ref] = (
+        datetime.datetime(1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc),
+        table,
     )
 
     with pytest.warns(UserWarning, match=re.escape("use_cache=False")):
