
Commit 9c21323

feat: Allow manually set clustering_columns in dataframe.to_gbq (#302)
* feat: Allow manually set clustering_columns in dataframe.to_gbq
* Update if_exists check.
* Update test.
1 parent a01b271 · commit 9c21323

File tree

4 files changed: +150, -12 lines

bigframes/dataframe.py (+68, -11)

```diff
@@ -2499,25 +2499,17 @@ def to_gbq(
         if_exists: Optional[Literal["fail", "replace", "append"]] = None,
         index: bool = True,
         ordering_id: Optional[str] = None,
+        clustering_columns: Union[pandas.Index, Iterable[typing.Hashable]] = (),
     ) -> str:
         dispositions = {
             "fail": bigquery.WriteDisposition.WRITE_EMPTY,
             "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
             "append": bigquery.WriteDisposition.WRITE_APPEND,
         }
 
-        if destination_table is None:
-            # TODO(swast): If there have been no modifications to the DataFrame
-            # since the last time it was written (cached), then return that.
-            # For `read_gbq` nodes, return the underlying table clone.
-            destination_table = bigframes.session._io.bigquery.create_temp_table(
-                self._session.bqclient,
-                self._session._anonymous_dataset,
-                # TODO(swast): allow custom expiration times, probably via session configuration.
-                datetime.datetime.now(datetime.timezone.utc)
-                + constants.DEFAULT_EXPIRATION,
-            )
+        temp_table_ref = None
 
+        if destination_table is None:
             if if_exists is not None and if_exists != "replace":
                 raise ValueError(
                     f"Got invalid value {repr(if_exists)} for if_exists. "
@@ -2526,6 +2518,11 @@ def to_gbq(
                 )
             if_exists = "replace"
 
+            temp_table_ref = bigframes.session._io.bigquery.random_table(
+                self._session._anonymous_dataset
+            )
+            destination_table = f"{temp_table_ref.project}.{temp_table_ref.dataset_id}.{temp_table_ref.table_id}"
+
         table_parts = destination_table.split(".")
         default_project = self._block.expr.session.bqclient.project
 
@@ -2553,15 +2550,29 @@ def to_gbq(
         except google.api_core.exceptions.NotFound:
             self._session.bqclient.create_dataset(destination_dataset, exists_ok=True)
 
+        clustering_fields = self._map_clustering_columns(
+            clustering_columns, index=index
+        )
+
         job_config = bigquery.QueryJobConfig(
             write_disposition=dispositions[if_exists],
             destination=bigquery.table.TableReference.from_string(
                 destination_table,
                 default_project=default_project,
             ),
+            clustering_fields=clustering_fields if clustering_fields else None,
         )
 
         self._run_io_query(index=index, ordering_id=ordering_id, job_config=job_config)
+
+        if temp_table_ref:
+            bigframes.session._io.bigquery.set_table_expiration(
+                self._session.bqclient,
+                temp_table_ref,
+                datetime.datetime.now(datetime.timezone.utc)
+                + constants.DEFAULT_EXPIRATION,
+            )
+
         return destination_table
 
     def to_numpy(
@@ -2756,6 +2767,52 @@ def _apply_unary_op(self, operation: ops.UnaryOp) -> DataFrame:
         block = self._block.multi_apply_unary_op(self._block.value_columns, operation)
         return DataFrame(block)
 
+    def _map_clustering_columns(
+        self,
+        clustering_columns: Union[pandas.Index, Iterable[typing.Hashable]],
+        index: bool,
+    ) -> List[str]:
+        """Maps the provided clustering columns to the existing columns in the DataFrame."""
+
+        def map_columns_on_occurrence(columns):
+            mapped_columns = []
+            for col in clustering_columns:
+                if col in columns:
+                    count = columns.count(col)
+                    mapped_columns.extend([col] * count)
+            return mapped_columns
+
+        if not clustering_columns:
+            return []
+
+        if len(list(clustering_columns)) != len(set(clustering_columns)):
+            raise ValueError("Duplicates are not supported in clustering_columns")
+
+        all_possible_columns = (
+            (set(self.columns) | set(self.index.names)) if index else set(self.columns)
+        )
+        missing_columns = set(clustering_columns) - all_possible_columns
+        if missing_columns:
+            raise ValueError(
+                f"Clustering columns not found in DataFrame: {missing_columns}"
+            )
+
+        clustering_columns_for_df = map_columns_on_occurrence(
+            list(self._block.column_labels)
+        )
+        clustering_columns_for_index = (
+            map_columns_on_occurrence(list(self.index.names)) if index else []
+        )
+
+        (
+            clustering_columns_for_df,
+            clustering_columns_for_index,
+        ) = utils.get_standardized_ids(
+            clustering_columns_for_df, clustering_columns_for_index
+        )
+
+        return clustering_columns_for_index + clustering_columns_for_df
+
     def _create_io_query(self, index: bool, ordering_id: Optional[str]) -> str:
         """Create query text representing this dataframe for I/O."""
         array_value = self._block.expr
```
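
For intuition about the ordering rule `_map_clustering_columns` implements, here is a minimal standalone sketch of the same logic. The function name, the simplified inputs, and the omission of the `utils.get_standardized_ids` renaming step are illustrative assumptions, not the BigFrames implementation:

```python
# A minimal sketch of the clustering-order rule: index columns first, then
# DataFrame columns, each group keeping the caller's requested order.
from typing import Hashable, Iterable, List


def map_clustering_columns(
    clustering_columns: Iterable[Hashable],
    df_columns: List[str],
    index_names: List[str],
    index: bool = True,
) -> List[str]:
    requested = list(clustering_columns)
    if len(requested) != len(set(requested)):
        raise ValueError("Duplicates are not supported in clustering_columns")

    allowed = set(df_columns) | (set(index_names) if index else set())
    missing = set(requested) - allowed
    if missing:
        raise ValueError(f"Clustering columns not found in DataFrame: {missing}")

    # Index columns always precede DataFrame columns; within each group the
    # order given in `clustering_columns` is preserved.
    for_index = [c for c in requested if index and c in index_names]
    for_df = [c for c in requested if c in df_columns]
    return for_index + for_df


# "idx" is an index level; "a" and "b" are DataFrame columns.
print(map_clustering_columns(["b", "idx", "a"], ["a", "b"], ["idx"]))
# ['idx', 'b', 'a']
```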

bigframes/session/_io/bigquery.py (+11)

```diff
@@ -150,6 +150,17 @@ def create_temp_table(
     return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
 
 
+def set_table_expiration(
+    bqclient: bigquery.Client,
+    table_ref: bigquery.TableReference,
+    expiration: datetime.datetime,
+) -> None:
+    """Set an expiration time for an existing BigQuery table."""
+    table = bqclient.get_table(table_ref)
+    table.expires = expiration
+    bqclient.update_table(table, ["expires"])
+
+
 # BigQuery REST API returns types in Legacy SQL format
 # https://cloud.google.com/bigquery/docs/data-types but we use Standard SQL
 # names
```
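
The helper follows the standard read-modify-update pattern of the `google-cloud-bigquery` client: fetch the table, mutate the property, then send a partial update naming only the changed field. A hedged usage sketch, with a placeholder table name and assuming installed client libraries and configured credentials:

```python
# Sketch of the same fetch/mutate/partial-update pattern set_table_expiration
# uses; "my-project.my_dataset.my_temp_table" is a hypothetical table.
import datetime

from google.cloud import bigquery


def main() -> None:
    client = bigquery.Client()
    table_ref = bigquery.TableReference.from_string(
        "my-project.my_dataset.my_temp_table"
    )

    table = client.get_table(table_ref)
    table.expires = datetime.datetime.now(
        datetime.timezone.utc
    ) + datetime.timedelta(days=7)
    # Only the "expires" field is sent in the update request.
    client.update_table(table, ["expires"])


if __name__ == "__main__":
    main()
```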

tests/system/small/test_dataframe_io.py (+49)

```diff
@@ -317,6 +317,55 @@ def test_to_gbq_w_None_column_names(
     )
 
 
+@pytest.mark.parametrize(
+    "clustering_columns",
+    [
+        pytest.param(["int64_col", "geography_col"]),
+        pytest.param(
+            ["float64_col"],
+            marks=pytest.mark.xfail(raises=google.api_core.exceptions.BadRequest),
+        ),
+        pytest.param(
+            ["int64_col", "int64_col"],
+            marks=pytest.mark.xfail(raises=ValueError),
+        ),
+    ],
+)
+def test_to_gbq_w_clustering(
+    scalars_df_default_index,
+    dataset_id,
+    bigquery_client,
+    clustering_columns,
+):
+    """Test the `to_gbq` API for creating clustered tables."""
+    destination_table = (
+        f"{dataset_id}.test_to_gbq_clustering_{'_'.join(clustering_columns)}"
+    )
+
+    scalars_df_default_index.to_gbq(
+        destination_table, clustering_columns=clustering_columns
+    )
+    table = bigquery_client.get_table(destination_table)
+
+    assert list(table.clustering_fields) == clustering_columns
+    assert table.expires is None
+
+
+def test_to_gbq_w_clustering_no_destination(
+    scalars_df_default_index,
+    bigquery_client,
+):
+    """Test the `to_gbq` API for creating clustered tables without destination."""
+    clustering_columns = ["int64_col", "geography_col"]
+    destination_table = scalars_df_default_index.to_gbq(
+        clustering_columns=clustering_columns
+    )
+    table = bigquery_client.get_table(destination_table)
+
+    assert list(table.clustering_fields) == clustering_columns
+    assert table.expires is not None
+
+
 def test_to_gbq_w_invalid_destination_table(scalars_df_index):
     with pytest.raises(ValueError):
         scalars_df_index.to_gbq("table_id")
```
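
The `float64_col` case is expected to fail server-side because BigQuery does not allow FLOAT64 as a clustering column type (INT64 and GEOGRAPHY are clusterable), while the duplicate case fails client-side in `_map_clustering_columns`. The same metadata check the tests perform can be done directly with the client; a hedged sketch with a placeholder table name, assuming a table already written by `to_gbq`:

```python
# Inspect clustering metadata on an existing table; the table name below is
# a placeholder, not a real resource.
from google.cloud import bigquery

client = bigquery.Client()
table = client.get_table("my-project.my_dataset.my_clustered_table")

# clustering_fields is None for unclustered tables, so normalize before
# comparing; BigQuery preserves the column order given at creation time.
assert list(table.clustering_fields or []) == ["int64_col", "geography_col"]
```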

third_party/bigframes_vendored/pandas/core/frame.py (+22, -1)

```diff
@@ -11,9 +11,10 @@
 """
 from __future__ import annotations
 
-from typing import Literal, Mapping, Optional, Sequence, Union
+from typing import Hashable, Iterable, Literal, Mapping, Optional, Sequence, Union
 
 import numpy as np
+import pandas as pd
 
 from bigframes import constants
 from third_party.bigframes_vendored.pandas.core.generic import NDFrame
@@ -307,6 +308,7 @@ def to_gbq(
         if_exists: Optional[Literal["fail", "replace", "append"]] = None,
         index: bool = True,
         ordering_id: Optional[str] = None,
+        clustering_columns: Union[pd.Index, Iterable[Hashable]] = (),
     ) -> str:
         """Write a DataFrame to a BigQuery table.
 
@@ -336,6 +338,16 @@ def to_gbq(
             <BLANKLINE>
         [2 rows x 2 columns]
 
+        Write a DataFrame to a BigQuery table with clustering columns:
+        >>> df = bpd.DataFrame({'col1': [1, 2], 'col2': [3, 4], 'col3': [5, 6]})
+        >>> clustering_cols = ['col1', 'col3']
+        >>> df.to_gbq(
+        ...     "bigframes-dev.birds.test-clusters",
+        ...     if_exists="replace",
+        ...     clustering_columns=clustering_cols,
+        ... )
+        'bigframes-dev.birds.test-clusters'
+
         Args:
             destination_table (Optional[str]):
                 Name of table to be written, in the form ``dataset.tablename``
@@ -364,6 +376,15 @@ def to_gbq(
                 If set, write the ordering of the DataFrame as a column in the
                 result table with this name.
 
+            clustering_columns (Union[pd.Index, Iterable[Hashable]], default ()):
+                Specifies the columns for clustering in the BigQuery table. The
+                order of columns in this list is significant for the clustering
+                hierarchy. Index columns may be included in clustering if the
+                `index` parameter is set to True and their names are given in
+                this argument. These index columns, if included, precede
+                DataFrame columns in the clustering order. Within each group,
+                the order follows `clustering_columns`.
+
         Returns:
             str:
                 The fully-qualified ID for the written table, in the form
```
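
To illustrate the index-clustering behavior the docstring describes, here is a hedged sketch; the destination project, dataset, and table names are placeholders, and a configured BigFrames session is assumed:

```python
# Cluster on an index level plus a regular column.
import bigframes.pandas as bpd

df = bpd.DataFrame(
    {"region": ["us", "eu", "us"], "sales": [10, 20, 30]}
).set_index("region")

# With index=True (the default), the index level "region" can be named in
# clustering_columns; index columns precede DataFrame columns in the final
# clustering order, so this clusters by region first, then sales.
df.to_gbq(
    "my-project.my_dataset.sales_clustered",
    if_exists="replace",
    clustering_columns=["region", "sales"],
)
```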
