26
26
import bigframes .core .expression as ex
27
27
import bigframes .core .ordering as ordering
28
28
import bigframes .core .window_spec as windows
29
+ import bigframes .dtypes
29
30
import bigframes .dtypes as dtypes
30
31
import bigframes .operations as ops
31
32
import bigframes .operations .aggregations as agg_ops
@@ -409,6 +410,8 @@ def rank(
409
410
method : str = "average" ,
410
411
na_option : str = "keep" ,
411
412
ascending : bool = True ,
413
+ grouping_cols : tuple [str , ...] = (),
414
+ columns : tuple [str , ...] = (),
412
415
):
413
416
if method not in ["average" , "min" , "max" , "first" , "dense" ]:
414
417
raise ValueError (
@@ -417,8 +420,8 @@ def rank(
417
420
if na_option not in ["keep" , "top" , "bottom" ]:
418
421
raise ValueError ("na_option must be one of 'keep', 'top', or 'bottom'" )
419
422
420
- columns = block .value_columns
421
- labels = block .column_labels
423
+ columns = columns or tuple ( col for col in block .value_columns )
424
+ labels = [ block .col_id_to_label [ id ] for id in columns ]
422
425
# Step 1: Calculate row numbers for each row
423
426
# Identify null values to be treated according to na_option param
424
427
rownum_col_ids = []
@@ -442,9 +445,13 @@ def rank(
442
445
block , rownum_id = block .apply_window_op (
443
446
col if na_option == "keep" else nullity_col_id ,
444
447
agg_ops .dense_rank_op if method == "dense" else agg_ops .count_op ,
445
- window_spec = windows .unbound (ordering = window_ordering )
448
+ window_spec = windows .unbound (
449
+ grouping_keys = grouping_cols , ordering = window_ordering
450
+ )
446
451
if method == "dense"
447
- else windows .rows (following = 0 , ordering = window_ordering ),
452
+ else windows .rows (
453
+ following = 0 , ordering = window_ordering , grouping_keys = grouping_cols
454
+ ),
448
455
skip_reproject_unsafe = (col != columns [- 1 ]),
449
456
)
450
457
rownum_col_ids .append (rownum_id )
@@ -462,12 +469,32 @@ def rank(
462
469
block , result_id = block .apply_window_op (
463
470
rownum_col_ids [i ],
464
471
agg_op ,
465
- window_spec = windows .unbound (grouping_keys = (columns [i ],)),
472
+ window_spec = windows .unbound (grouping_keys = (columns [i ], * grouping_cols )),
466
473
skip_reproject_unsafe = (i < (len (columns ) - 1 )),
467
474
)
468
475
post_agg_rownum_col_ids .append (result_id )
469
476
rownum_col_ids = post_agg_rownum_col_ids
470
477
478
+ # Pandas masks all values where any grouping column is null
479
+ # Note: we use pd.NA instead of float('nan')
480
+ if grouping_cols :
481
+ predicate = functools .reduce (
482
+ ops .and_op .as_expr ,
483
+ [ops .notnull_op .as_expr (column_id ) for column_id in grouping_cols ],
484
+ )
485
+ block = block .project_exprs (
486
+ [
487
+ ops .where_op .as_expr (
488
+ ex .deref (col ),
489
+ predicate ,
490
+ ex .const (None ),
491
+ )
492
+ for col in rownum_col_ids
493
+ ],
494
+ labels = labels ,
495
+ )
496
+ rownum_col_ids = list (block .value_columns [- len (rownum_col_ids ) :])
497
+
471
498
# Step 3: post processing: mask null values and cast to float
472
499
if method in ["min" , "max" , "first" , "dense" ]:
473
500
# Pandas rank always produces Float64, so must cast for aggregation types that produce ints
0 commit comments