Skip to content

Commit 87e6018

Browse files
Genesis929, tswast, and gcf-owl-bot[bot]
authored
feat: add .agg support for size (#792)
* feat: add .agg support for size (#792)
* undo test change.
* logic fix.
* type update
* Apply suggestions from code review
* 🦉 Updates from OwlBot post-processor

  See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* 🦉 Updates from OwlBot post-processor

  See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

---------

Co-authored-by: Tim Sweña (Swast) <[email protected]>
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 580e1b9 commit 87e6018

File tree

7 files changed

+126
-36
lines changed

7 files changed

+126
-36
lines changed

bigframes/core/blocks.py

+46-9
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ def filter(self, predicate: scalars.Expression):
995995

996996
def aggregate_all_and_stack(
997997
self,
998-
operation: agg_ops.UnaryAggregateOp,
998+
operation: typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp],
999999
*,
10001000
axis: int | str = 0,
10011001
value_col_id: str = "values",
@@ -1004,7 +1004,12 @@ def aggregate_all_and_stack(
10041004
axis_n = utils.get_axis_number(axis)
10051005
if axis_n == 0:
10061006
aggregations = [
1007-
(ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
1007+
(
1008+
ex.UnaryAggregation(operation, ex.free_var(col_id))
1009+
if isinstance(operation, agg_ops.UnaryAggregateOp)
1010+
else ex.NullaryAggregation(operation),
1011+
col_id,
1012+
)
10081013
for col_id in self.value_columns
10091014
]
10101015
index_id = guid.generate_guid()
@@ -1033,6 +1038,11 @@ def aggregate_all_and_stack(
10331038
(ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
10341039
for col_id in [*self.index_columns]
10351040
]
1041+
# TODO: may need add NullaryAggregation in main_aggregation
1042+
# when agg add support for axis=1, needed for agg("size", axis=1)
1043+
assert isinstance(
1044+
operation, agg_ops.UnaryAggregateOp
1045+
), f"Expected a unary operation, but got {operation}. Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
10361046
main_aggregation = (
10371047
ex.UnaryAggregation(operation, ex.free_var(value_col_id)),
10381048
value_col_id,
@@ -1125,7 +1135,11 @@ def remap_f(x):
11251135
def aggregate(
11261136
self,
11271137
by_column_ids: typing.Sequence[str] = (),
1128-
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.UnaryAggregateOp]] = (),
1138+
aggregations: typing.Sequence[
1139+
typing.Tuple[
1140+
str, typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
1141+
]
1142+
] = (),
11291143
*,
11301144
dropna: bool = True,
11311145
) -> typing.Tuple[Block, typing.Sequence[str]]:
@@ -1139,7 +1153,9 @@ def aggregate(
11391153
"""
11401154
agg_specs = [
11411155
(
1142-
ex.UnaryAggregation(operation, ex.free_var(input_id)),
1156+
ex.UnaryAggregation(operation, ex.free_var(input_id))
1157+
if isinstance(operation, agg_ops.UnaryAggregateOp)
1158+
else ex.NullaryAggregation(operation),
11431159
guid.generate_guid(),
11441160
)
11451161
for input_id, operation in aggregations
@@ -1175,18 +1191,32 @@ def aggregate(
11751191
output_col_ids,
11761192
)
11771193

1178-
def get_stat(self, column_id: str, stat: agg_ops.UnaryAggregateOp):
1194+
def get_stat(
1195+
self,
1196+
column_id: str,
1197+
stat: typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp],
1198+
):
11791199
"""Gets aggregates immediately, and caches it"""
11801200
if stat.name in self._stats_cache[column_id]:
11811201
return self._stats_cache[column_id][stat.name]
11821202

11831203
# TODO: Convert nonstandard stats into standard stats where possible (popvar, etc.)
11841204
# if getting a standard stat, just go get the rest of them
1185-
standard_stats = self._standard_stats(column_id)
1205+
standard_stats = typing.cast(
1206+
typing.Sequence[
1207+
typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
1208+
],
1209+
self._standard_stats(column_id),
1210+
)
11861211
stats_to_fetch = standard_stats if stat in standard_stats else [stat]
11871212

11881213
aggregations = [
1189-
(ex.UnaryAggregation(stat, ex.free_var(column_id)), stat.name)
1214+
(
1215+
ex.UnaryAggregation(stat, ex.free_var(column_id))
1216+
if isinstance(stat, agg_ops.UnaryAggregateOp)
1217+
else ex.NullaryAggregation(stat),
1218+
stat.name,
1219+
)
11901220
for stat in stats_to_fetch
11911221
]
11921222
expr = self.expr.aggregate(aggregations)
@@ -1231,13 +1261,20 @@ def get_binary_stat(
12311261
def summarize(
12321262
self,
12331263
column_ids: typing.Sequence[str],
1234-
stats: typing.Sequence[agg_ops.UnaryAggregateOp],
1264+
stats: typing.Sequence[
1265+
typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
1266+
],
12351267
):
12361268
"""Get a list of stats as a deferred block object."""
12371269
label_col_id = guid.generate_guid()
12381270
labels = [stat.name for stat in stats]
12391271
aggregations = [
1240-
(ex.UnaryAggregation(stat, ex.free_var(col_id)), f"{col_id}-{stat.name}")
1272+
(
1273+
ex.UnaryAggregation(stat, ex.free_var(col_id))
1274+
if isinstance(stat, agg_ops.UnaryAggregateOp)
1275+
else ex.NullaryAggregation(stat),
1276+
f"{col_id}-{stat.name}",
1277+
)
12411278
for stat in stats
12421279
for col_id in column_ids
12431280
]

bigframes/core/groupby/__init__.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
286286
block, window_spec, self._selected_cols, drop_null_groups=self._dropna
287287
)
288288

289-
def agg(self, func=None, **kwargs) -> df.DataFrame:
289+
def agg(self, func=None, **kwargs) -> typing.Union[df.DataFrame, series.Series]:
290290
if func:
291291
if isinstance(func, str):
292-
return self._agg_string(func)
292+
return self.size() if func == "size" else self._agg_string(func)
293293
elif utils.is_dict_like(func):
294294
return self._agg_dict(func)
295295
elif utils.is_list_like(func):
@@ -315,7 +315,11 @@ def _agg_string(self, func: str) -> df.DataFrame:
315315
return dataframe if self._as_index else self._convert_index(dataframe)
316316

317317
def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
318-
aggregations: typing.List[typing.Tuple[str, agg_ops.UnaryAggregateOp]] = []
318+
aggregations: typing.List[
319+
typing.Tuple[
320+
str, typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
321+
]
322+
] = []
319323
column_labels = []
320324

321325
want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())

bigframes/operations/aggregations.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,9 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
487487

488488

489489
# TODO: Alternative names and lookup from numpy function objects
490-
_AGGREGATIONS_LOOKUP: dict[str, UnaryAggregateOp] = {
490+
_AGGREGATIONS_LOOKUP: typing.Dict[
491+
str, typing.Union[UnaryAggregateOp, NullaryAggregateOp]
492+
] = {
491493
op.name: op
492494
for op in [
493495
sum_op,
@@ -506,10 +508,14 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
506508
ApproxQuartilesOp(2),
507509
ApproxQuartilesOp(3),
508510
]
511+
+ [
512+
# Add size_op separately to avoid Mypy type inference errors.
513+
size_op,
514+
]
509515
}
510516

511517

512-
def lookup_agg_func(key: str) -> UnaryAggregateOp:
518+
def lookup_agg_func(key: str) -> typing.Union[UnaryAggregateOp, NullaryAggregateOp]:
513519
if callable(key):
514520
raise NotImplementedError(
515521
"Aggregating with callable object not supported, pass method name as string instead (eg. 'sum' instead of np.sum)."

bigframes/series.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,6 @@ def agg(self, func: str | typing.Sequence[str]) -> scalars.Scalar | Series:
968968
)
969969
)
970970
else:
971-
972971
return self._apply_aggregation(
973972
agg_ops.lookup_agg_func(typing.cast(str, func))
974973
)
@@ -1246,7 +1245,9 @@ def _align3(self, other1: Series | scalars.Scalar, other2: Series | scalars.Scal
12461245
values, index = self._align_n([other1, other2], how)
12471246
return (values[0], values[1], values[2], index)
12481247

1249-
def _apply_aggregation(self, op: agg_ops.UnaryAggregateOp) -> Any:
1248+
def _apply_aggregation(
1249+
self, op: agg_ops.UnaryAggregateOp | agg_ops.NullaryAggregateOp
1250+
) -> Any:
12501251
return self._block.get_stat(self._value_column, op)
12511252

12521253
def _apply_window_op(

tests/system/small/test_dataframe.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -2485,12 +2485,19 @@ def test_dataframe_agg_single_string(scalars_dfs):
24852485
)
24862486

24872487

2488-
def test_dataframe_agg_int_single_string(scalars_dfs):
2488+
@pytest.mark.parametrize(
2489+
("agg",),
2490+
(
2491+
("sum",),
2492+
("size",),
2493+
),
2494+
)
2495+
def test_dataframe_agg_int_single_string(scalars_dfs, agg):
24892496
numeric_cols = ["int64_col", "int64_too", "bool_col"]
24902497
scalars_df, scalars_pandas_df = scalars_dfs
24912498

2492-
bf_result = scalars_df[numeric_cols].agg("sum").to_pandas()
2493-
pd_result = scalars_pandas_df[numeric_cols].agg("sum")
2499+
bf_result = scalars_df[numeric_cols].agg(agg).to_pandas()
2500+
pd_result = scalars_pandas_df[numeric_cols].agg(agg)
24942501

24952502
assert bf_result.dtype == "Int64"
24962503
pd.testing.assert_series_equal(
@@ -2537,6 +2544,7 @@ def test_dataframe_agg_int_multi_string(scalars_dfs):
25372544
"sum",
25382545
"nunique",
25392546
"count",
2547+
"size",
25402548
]
25412549
scalars_df, scalars_pandas_df = scalars_dfs
25422550
bf_result = scalars_df[numeric_cols].agg(aggregations).to_pandas()

tests/system/small/test_groupby.py

+30-13
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,23 @@ def test_dataframe_groupby_agg_string(
140140
)
141141

142142

143+
def test_dataframe_groupby_agg_size_string(scalars_df_index, scalars_pandas_df_index):
144+
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
145+
bf_result = scalars_df_index[col_names].groupby("string_col").agg("size")
146+
pd_result = scalars_pandas_df_index[col_names].groupby("string_col").agg("size")
147+
148+
pd.testing.assert_series_equal(pd_result, bf_result.to_pandas(), check_dtype=False)
149+
150+
143151
def test_dataframe_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
144152
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
145-
bf_result = scalars_df_index[col_names].groupby("string_col").agg(["count", "min"])
153+
bf_result = (
154+
scalars_df_index[col_names].groupby("string_col").agg(["count", "min", "size"])
155+
)
146156
pd_result = (
147-
scalars_pandas_df_index[col_names].groupby("string_col").agg(["count", "min"])
157+
scalars_pandas_df_index[col_names]
158+
.groupby("string_col")
159+
.agg(["count", "min", "size"])
148160
)
149161
bf_result_computed = bf_result.to_pandas()
150162

@@ -161,8 +173,8 @@ def test_dataframe_groupby_agg_list_w_column_multi_index(
161173
pd_df = scalars_pandas_df_index[columns].copy()
162174
pd_df.columns = multi_columns
163175

164-
bf_result = bf_df.groupby(level=0).agg(["count", "min"])
165-
pd_result = pd_df.groupby(level=0).agg(["count", "min"])
176+
bf_result = bf_df.groupby(level=0).agg(["count", "min", "size"])
177+
pd_result = pd_df.groupby(level=0).agg(["count", "min", "size"])
166178

167179
bf_result_computed = bf_result.to_pandas()
168180
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
@@ -182,12 +194,12 @@ def test_dataframe_groupby_agg_dict_with_list(
182194
bf_result = (
183195
scalars_df_index[col_names]
184196
.groupby("string_col", as_index=as_index)
185-
.agg({"int64_too": ["mean", "max"], "string_col": "count"})
197+
.agg({"int64_too": ["mean", "max"], "string_col": "count", "bool_col": "size"})
186198
)
187199
pd_result = (
188200
scalars_pandas_df_index[col_names]
189201
.groupby("string_col", as_index=as_index)
190-
.agg({"int64_too": ["mean", "max"], "string_col": "count"})
202+
.agg({"int64_too": ["mean", "max"], "string_col": "count", "bool_col": "size"})
191203
)
192204
bf_result_computed = bf_result.to_pandas()
193205

@@ -413,16 +425,21 @@ def test_dataframe_groupby_nonnumeric_with_mean():
413425
# ==============
414426

415427

416-
def test_series_groupby_agg_string(scalars_df_index, scalars_pandas_df_index):
428+
@pytest.mark.parametrize(
429+
("agg"),
430+
[
431+
("count"),
432+
("size"),
433+
],
434+
)
435+
def test_series_groupby_agg_string(scalars_df_index, scalars_pandas_df_index, agg):
417436
bf_result = (
418-
scalars_df_index["int64_col"]
419-
.groupby(scalars_df_index["string_col"])
420-
.agg("count")
437+
scalars_df_index["int64_col"].groupby(scalars_df_index["string_col"]).agg(agg)
421438
)
422439
pd_result = (
423440
scalars_pandas_df_index["int64_col"]
424441
.groupby(scalars_pandas_df_index["string_col"])
425-
.agg("count")
442+
.agg(agg)
426443
)
427444
bf_result_computed = bf_result.to_pandas()
428445

@@ -435,12 +452,12 @@ def test_series_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
435452
bf_result = (
436453
scalars_df_index["int64_col"]
437454
.groupby(scalars_df_index["string_col"])
438-
.agg(["sum", "mean"])
455+
.agg(["sum", "mean", "size"])
439456
)
440457
pd_result = (
441458
scalars_pandas_df_index["int64_col"]
442459
.groupby(scalars_pandas_df_index["string_col"])
443-
.agg(["sum", "mean"])
460+
.agg(["sum", "mean", "size"])
444461
)
445462
bf_result_computed = bf_result.to_pandas()
446463

tests/system/small/test_series.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -506,15 +506,32 @@ def test_series_dropna(scalars_dfs, ignore_index):
506506
pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False)
507507

508508

509-
def test_series_agg_single_string(scalars_dfs):
509+
@pytest.mark.parametrize(
510+
("agg",),
511+
(
512+
("sum",),
513+
("size",),
514+
),
515+
)
516+
def test_series_agg_single_string(scalars_dfs, agg):
510517
scalars_df, scalars_pandas_df = scalars_dfs
511-
bf_result = scalars_df["int64_col"].agg("sum")
512-
pd_result = scalars_pandas_df["int64_col"].agg("sum")
518+
bf_result = scalars_df["int64_col"].agg(agg)
519+
pd_result = scalars_pandas_df["int64_col"].agg(agg)
513520
assert math.isclose(pd_result, bf_result)
514521

515522

516523
def test_series_agg_multi_string(scalars_dfs):
517-
aggregations = ["sum", "mean", "std", "var", "min", "max", "nunique", "count"]
524+
aggregations = [
525+
"sum",
526+
"mean",
527+
"std",
528+
"var",
529+
"min",
530+
"max",
531+
"nunique",
532+
"count",
533+
"size",
534+
]
518535
scalars_df, scalars_pandas_df = scalars_dfs
519536
bf_result = scalars_df["int64_col"].agg(aggregations).to_pandas()
520537
pd_result = scalars_pandas_df["int64_col"].agg(aggregations)

0 commit comments

Comments (0)