Skip to content

Commit 3a633d5

Browse files
feat: Add groupby.rank() (#1433)
1 parent ddfd02a commit 3a633d5

File tree

5 files changed

+265
-6
lines changed

5 files changed

+265
-6
lines changed

bigframes/core/block_transforms.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import bigframes.core.expression as ex
2727
import bigframes.core.ordering as ordering
2828
import bigframes.core.window_spec as windows
29+
import bigframes.dtypes
2930
import bigframes.dtypes as dtypes
3031
import bigframes.operations as ops
3132
import bigframes.operations.aggregations as agg_ops
@@ -409,6 +410,8 @@ def rank(
409410
method: str = "average",
410411
na_option: str = "keep",
411412
ascending: bool = True,
413+
grouping_cols: tuple[str, ...] = (),
414+
columns: tuple[str, ...] = (),
412415
):
413416
if method not in ["average", "min", "max", "first", "dense"]:
414417
raise ValueError(
@@ -417,8 +420,8 @@ def rank(
417420
if na_option not in ["keep", "top", "bottom"]:
418421
raise ValueError("na_option must be one of 'keep', 'top', or 'bottom'")
419422

420-
columns = block.value_columns
421-
labels = block.column_labels
423+
columns = columns or tuple(col for col in block.value_columns)
424+
labels = [block.col_id_to_label[id] for id in columns]
422425
# Step 1: Calculate row numbers for each row
423426
# Identify null values to be treated according to na_option param
424427
rownum_col_ids = []
@@ -442,9 +445,13 @@ def rank(
442445
block, rownum_id = block.apply_window_op(
443446
col if na_option == "keep" else nullity_col_id,
444447
agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op,
445-
window_spec=windows.unbound(ordering=window_ordering)
448+
window_spec=windows.unbound(
449+
grouping_keys=grouping_cols, ordering=window_ordering
450+
)
446451
if method == "dense"
447-
else windows.rows(following=0, ordering=window_ordering),
452+
else windows.rows(
453+
following=0, ordering=window_ordering, grouping_keys=grouping_cols
454+
),
448455
skip_reproject_unsafe=(col != columns[-1]),
449456
)
450457
rownum_col_ids.append(rownum_id)
@@ -462,12 +469,32 @@ def rank(
462469
block, result_id = block.apply_window_op(
463470
rownum_col_ids[i],
464471
agg_op,
465-
window_spec=windows.unbound(grouping_keys=(columns[i],)),
472+
window_spec=windows.unbound(grouping_keys=(columns[i], *grouping_cols)),
466473
skip_reproject_unsafe=(i < (len(columns) - 1)),
467474
)
468475
post_agg_rownum_col_ids.append(result_id)
469476
rownum_col_ids = post_agg_rownum_col_ids
470477

478+
# Pandas masks all values where any grouping column is null
479+
# Note: we use pd.NA instead of float('nan')
480+
if grouping_cols:
481+
predicate = functools.reduce(
482+
ops.and_op.as_expr,
483+
[ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
484+
)
485+
block = block.project_exprs(
486+
[
487+
ops.where_op.as_expr(
488+
ex.deref(col),
489+
predicate,
490+
ex.const(None),
491+
)
492+
for col in rownum_col_ids
493+
],
494+
labels=labels,
495+
)
496+
rownum_col_ids = list(block.value_columns[-len(rownum_col_ids) :])
497+
471498
# Step 3: post processing: mask null values and cast to float
472499
if method in ["min", "max", "first", "dense"]:
473500
# Pandas rank always produces Float64, so must cast for aggregation types that produce ints

bigframes/core/compile/ibis_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def literal_to_ibis_scalar(
397397
)
398398
# "correct" way would be to use ibis.array, but this produces invalid BQ SQL syntax
399399
return tuple(literal)
400+
400401
if not pd.api.types.is_list_like(literal) and pd.isna(literal):
401402
if ibis_dtype:
402403
return bigframes_vendored.ibis.null().cast(ibis_dtype)

bigframes/core/groupby/__init__.py

+28
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,20 @@ def median(self, numeric_only: bool = False, *, exact: bool = True) -> df.DataFr
174174
return self.quantile(0.5)
175175
return self._aggregate_all(agg_ops.median_op, numeric_only=True)
176176

177+
def rank(
178+
self, method="average", ascending: bool = True, na_option: str = "keep"
179+
) -> df.DataFrame:
180+
return df.DataFrame(
181+
block_ops.rank(
182+
self._block,
183+
method,
184+
na_option,
185+
ascending,
186+
grouping_cols=tuple(self._by_col_ids),
187+
columns=tuple(self._selected_cols),
188+
)
189+
)
190+
177191
def quantile(
178192
self, q: Union[float, Sequence[float]] = 0.5, *, numeric_only: bool = False
179193
) -> df.DataFrame:
@@ -574,6 +588,20 @@ def sum(self, *args) -> series.Series:
574588
def mean(self, *args) -> series.Series:
575589
return self._aggregate(agg_ops.mean_op)
576590

591+
def rank(
592+
self, method="average", ascending: bool = True, na_option: str = "keep"
593+
) -> series.Series:
594+
return series.Series(
595+
block_ops.rank(
596+
self._block,
597+
method,
598+
na_option,
599+
ascending,
600+
grouping_cols=tuple(self._by_col_ids),
601+
columns=(self._value_column,),
602+
)
603+
)
604+
577605
def median(
578606
self,
579607
*args,

tests/system/small/test_groupby.py

+133-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717

1818
import bigframes.pandas as bpd
19-
from tests.system.utils import assert_pandas_df_equal
19+
from tests.system.utils import assert_pandas_df_equal, skip_legacy_pandas
2020

2121
# =================
2222
# DataFrame.groupby
@@ -94,6 +94,72 @@ def test_dataframe_groupby_quantile(scalars_df_index, scalars_pandas_df_index, q
9494
)
9595

9696

97+
@skip_legacy_pandas
98+
@pytest.mark.parametrize(
99+
("na_option", "method", "ascending"),
100+
[
101+
(
102+
"keep",
103+
"average",
104+
True,
105+
),
106+
(
107+
"top",
108+
"min",
109+
False,
110+
),
111+
(
112+
"bottom",
113+
"max",
114+
False,
115+
),
116+
(
117+
"top",
118+
"first",
119+
False,
120+
),
121+
(
122+
"bottom",
123+
"dense",
124+
False,
125+
),
126+
],
127+
)
128+
def test_dataframe_groupby_rank(
129+
scalars_df_index,
130+
scalars_pandas_df_index,
131+
na_option,
132+
method,
133+
ascending,
134+
):
135+
col_names = ["int64_too", "float64_col", "int64_col", "string_col"]
136+
bf_result = (
137+
scalars_df_index[col_names]
138+
.groupby("string_col")
139+
.rank(
140+
na_option=na_option,
141+
method=method,
142+
ascending=ascending,
143+
)
144+
).to_pandas()
145+
pd_result = (
146+
(
147+
scalars_pandas_df_index[col_names]
148+
.groupby("string_col")
149+
.rank(
150+
na_option=na_option,
151+
method=method,
152+
ascending=ascending,
153+
)
154+
)
155+
.astype("float64")
156+
.astype("Float64")
157+
)
158+
pd.testing.assert_frame_equal(
159+
pd_result, bf_result, check_dtype=False, check_index_type=False
160+
)
161+
162+
97163
@pytest.mark.parametrize(
98164
("operator"),
99165
[
@@ -534,6 +600,72 @@ def test_series_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
534600
)
535601

536602

603+
@skip_legacy_pandas
604+
@pytest.mark.parametrize(
605+
("na_option", "method", "ascending"),
606+
[
607+
(
608+
"keep",
609+
"average",
610+
True,
611+
),
612+
(
613+
"top",
614+
"min",
615+
False,
616+
),
617+
(
618+
"bottom",
619+
"max",
620+
False,
621+
),
622+
(
623+
"top",
624+
"first",
625+
False,
626+
),
627+
(
628+
"bottom",
629+
"dense",
630+
False,
631+
),
632+
],
633+
)
634+
def test_series_groupby_rank(
635+
scalars_df_index,
636+
scalars_pandas_df_index,
637+
na_option,
638+
method,
639+
ascending,
640+
):
641+
col_names = ["int64_col", "string_col"]
642+
bf_result = (
643+
scalars_df_index[col_names]
644+
.groupby("string_col")["int64_col"]
645+
.rank(
646+
na_option=na_option,
647+
method=method,
648+
ascending=ascending,
649+
)
650+
).to_pandas()
651+
pd_result = (
652+
(
653+
scalars_pandas_df_index[col_names]
654+
.groupby("string_col")["int64_col"]
655+
.rank(
656+
na_option=na_option,
657+
method=method,
658+
ascending=ascending,
659+
)
660+
)
661+
.astype("float64")
662+
.astype("Float64")
663+
)
664+
pd.testing.assert_series_equal(
665+
pd_result, bf_result, check_dtype=False, check_index_type=False
666+
)
667+
668+
537669
@pytest.mark.parametrize("dropna", [True, False])
538670
def test_series_groupby_head(scalars_df_index, scalars_pandas_df_index, dropna):
539671
bf_result = (

third_party/bigframes_vendored/pandas/core/groupby/__init__.py

+71
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,77 @@ def var(
363363
"""
364364
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
365365

366+
def rank(
367+
self,
368+
method: str = "average",
369+
ascending: bool = True,
370+
na_option: str = "keep",
371+
):
372+
"""
373+
Provide the rank of values within each group.
374+
375+
**Examples:**
376+
377+
>>> import bigframes.pandas as bpd
378+
>>> import numpy as np
379+
>>> bpd.options.display.progress_bar = None
380+
381+
>>> df = bpd.DataFrame(
382+
... {
383+
... "group": ["a", "a", "a", "a", "a", "b", "b", "b", "b", "b"],
384+
... "value": [2, 4, 2, 3, 5, 1, 2, 4, 1, 5],
385+
... }
386+
... )
387+
>>> df
388+
group value
389+
0 a 2
390+
1 a 4
391+
2 a 2
392+
3 a 3
393+
4 a 5
394+
5 b 1
395+
6 b 2
396+
7 b 4
397+
8 b 1
398+
9 b 5
399+
<BLANKLINE>
400+
[10 rows x 2 columns]
401+
>>> for method in ['average', 'min', 'max', 'dense', 'first']:
402+
... df[f'{method}_rank'] = df.groupby('group')['value'].rank(method)
403+
>>> df
404+
group value average_rank min_rank max_rank dense_rank first_rank
405+
0 a 2 1.5 1.0 2.0 1.0 1.0
406+
1 a 4 4.0 4.0 4.0 3.0 4.0
407+
2 a 2 1.5 1.0 2.0 1.0 2.0
408+
3 a 3 3.0 3.0 3.0 2.0 3.0
409+
4 a 5 5.0 5.0 5.0 4.0 5.0
410+
5 b 1 1.5 1.0 2.0 1.0 1.0
411+
6 b 2 3.0 3.0 3.0 2.0 3.0
412+
7 b 4 4.0 4.0 4.0 3.0 4.0
413+
8 b 1 1.5 1.0 2.0 1.0 2.0
414+
9 b 5 5.0 5.0 5.0 4.0 5.0
415+
<BLANKLINE>
416+
[10 rows x 7 columns]
417+
418+
Args:
419+
method ({'average', 'min', 'max', 'first', 'dense'}, default 'average'):
420+
* average: average rank of group.
421+
* min: lowest rank in group.
422+
* max: highest rank in group.
423+
* first: ranks assigned in order they appear in the array.
424+
* dense: like 'min', but rank always increases by 1 between groups.
425+
ascending (bool, default True):
426+
False for ranks by high (1) to low (N).
427+
na_option ({'keep', 'top', 'bottom'}, default 'keep'):
428+
* keep: leave NA values where they are.
429+
* top: smallest rank if ascending.
430+
* bottom: smallest rank if descending.
431+
432+
Returns:
433+
DataFrame with ranking of values within each group
434+
"""
435+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
436+
366437
def skew(
367438
self,
368439
*,

0 commit comments

Comments
 (0)