Skip to content

Commit 6a78c89

Browse files
authored
fix: correct index labels in multiple aggregations for DataFrameGroupBy (#723)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [X] Make sure to open an issue as a bug
- [X] Ensure the tests and linter pass
- [X] Code coverage does not decrease (if any source code was changed)
- [X] Appropriate docs were updated (if necessary)

Fixes internal issue 341157901 🦕
1 parent 0e25a3b commit 6a78c89

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

bigframes/core/groupby/__init__.py

+24-3
Original file line number | Diff line number | Diff line change
@@ -339,9 +339,30 @@ def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
339339
for col_id in self._aggregated_columns()
340340
for f in func
341341
]
342-
column_labels = [
343-
(col_id, f) for col_id in self._aggregated_columns() for f in func
344-
]
342+
343+
if self._block.column_labels.nlevels > 1:
344+
# Restructure MultiIndex for proper format: (idx1, idx2, func)
345+
# rather than ((idx1, idx2), func).
346+
aggregated_columns = pd.MultiIndex.from_tuples(
347+
[
348+
self._block.col_id_to_label[col_id]
349+
for col_id in self._aggregated_columns()
350+
],
351+
names=[*self._block.column_labels.names],
352+
).to_frame(index=False)
353+
354+
column_labels = [
355+
tuple(col_id) + (f,)
356+
for col_id in aggregated_columns.to_numpy()
357+
for f in func
358+
]
359+
else:
360+
column_labels = [
361+
(self._block.col_id_to_label[col_id], f)
362+
for col_id in self._aggregated_columns()
363+
for f in func
364+
]
365+
345366
agg_block, _ = self._block.aggregate(
346367
by_column_ids=self._by_col_ids,
347368
aggregations=aggregations,

tests/system/small/test_groupby.py

+17
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,23 @@ def test_dataframe_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
144144
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
145145

146146

147+
def test_dataframe_groupby_agg_list_w_column_multi_index(
148+
scalars_df_index, scalars_pandas_df_index
149+
):
150+
columns = ["int64_too", "string_col", "bool_col"]
151+
multi_columns = pd.MultiIndex.from_tuples(zip(["a", "b", "a"], columns))
152+
bf_df = scalars_df_index[columns].copy()
153+
bf_df.columns = multi_columns
154+
pd_df = scalars_pandas_df_index[columns].copy()
155+
pd_df.columns = multi_columns
156+
157+
bf_result = bf_df.groupby(level=0).agg(["count", "min"])
158+
pd_result = pd_df.groupby(level=0).agg(["count", "min"])
159+
160+
bf_result_computed = bf_result.to_pandas()
161+
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
162+
163+
147164
@pytest.mark.parametrize(
148165
("as_index"),
149166
[

third_party/bigframes_vendored/pandas/core/groupby/__init__.py

+22
Original file line number | Diff line number | Diff line change
@@ -1092,6 +1092,17 @@ def agg(self, func, **kwargs):
10921092
<BLANKLINE>
10931093
[2 rows x 2 columns]
10941094
1095+
Multiple aggregations
1096+
1097+
>>> df.groupby('A').agg(['min', 'max'])
1098+
    B             C
1099+
  min max       min       max
1100+
A
1101+
1   1   2  0.227877  0.362838
1102+
2   3   4  -0.56286  1.267767
1103+
<BLANKLINE>
1104+
[2 rows x 4 columns]
1105+
10951106
Args:
10961107
func (function, str, list, dict or None):
10971108
Function to use for aggregating the data.
@@ -1140,6 +1151,17 @@ def aggregate(self, func, **kwargs):
11401151
<BLANKLINE>
11411152
[2 rows x 2 columns]
11421153
1154+
Multiple aggregations
1155+
1156+
>>> df.groupby('A').agg(['min', 'max'])
1157+
    B             C
1158+
  min max       min       max
1159+
A
1160+
1   1   2  0.227877  0.362838
1161+
2   3   4  -0.56286  1.267767
1162+
<BLANKLINE>
1163+
[2 rows x 4 columns]
1164+
11431165
Args:
11441166
func (function, str, list, dict or None):
11451167
Function to use for aggregating the data.

0 commit comments

Comments
 (0)