@@ -354,10 +354,7 @@ def unpivot(
354
354
* ,
355
355
passthrough_columns : typing .Sequence [str ] = (),
356
356
index_col_ids : typing .Sequence [str ] = ["index" ],
357
- dtype : typing .Union [
358
- bigframes .dtypes .Dtype , typing .Tuple [bigframes .dtypes .Dtype , ...]
359
- ] = pandas .Float64Dtype (),
360
- how : typing .Literal ["left" , "right" ] = "left" ,
357
+ join_side : typing .Literal ["left" , "right" ] = "left" ,
361
358
) -> ArrayValue :
362
359
"""
363
360
Unpivot ArrayValue columns.
@@ -367,23 +364,88 @@ def unpivot(
367
364
unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
368
365
passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
369
366
index_col_ids (Sequence[str]): The column ids to be used for the row labels.
370
- dtype (dtype or list of dtype): Dtype to use for the unpivot columns. If list, must be equal in number to unpivot_columns.
371
367
372
368
Returns:
373
369
ArrayValue: The unpivoted ArrayValue
374
370
"""
371
+ # There will be N labels, used to disambiguate which of N source columns produced each output row
372
+ explode_offsets_id = bigframes .core .guid .generate_guid ("unpivot_offsets_" )
373
+ labels_array = self ._create_unpivot_labels_array (row_labels , index_col_ids )
374
+ labels_array = labels_array .promote_offsets (explode_offsets_id )
375
+
376
+ # Unpivot creates N output rows for each input row, labels disambiguate these N rows
377
+ joined_array = self ._cross_join_w_labels (labels_array , join_side )
378
+
379
+ # Build the output rows as a case statement that selects between the N input columns
380
+ unpivot_exprs = []
381
+ # Supports producing multiple stacked output columns for stacking only part of hierarchical index
382
+ for col_id , input_ids in unpivot_columns :
383
+ # row explode offset used to choose the input column
384
+ # we use offset instead of label as labels are not necessarily unique
385
+ cases = tuple (
386
+ (
387
+ ops .eq_op .as_expr (explode_offsets_id , ex .const (i )),
388
+ ex .free_var (id_or_null )
389
+ if (id_or_null is not None )
390
+ else ex .const (None ),
391
+ )
392
+ for i , id_or_null in enumerate (input_ids )
393
+ )
394
+ col_expr = ops .case_when_op .as_expr (* cases )
395
+ unpivot_exprs .append ((col_expr , col_id ))
396
+
397
+ label_exprs = ((ex .free_var (id ), id ) for id in index_col_ids )
398
+ # passthrough columns are unchanged, just repeated N times each
399
+ passthrough_exprs = ((ex .free_var (id ), id ) for id in passthrough_columns )
375
400
return ArrayValue (
376
- nodes .UnpivotNode (
377
- child = self .node ,
378
- row_labels = tuple (row_labels ),
379
- unpivot_columns = tuple (unpivot_columns ),
380
- passthrough_columns = tuple (passthrough_columns ),
381
- index_col_ids = tuple (index_col_ids ),
382
- dtype = dtype ,
383
- how = how ,
401
+ nodes .ProjectionNode (
402
+ child = joined_array .node ,
403
+ assignments = (* label_exprs , * unpivot_exprs , * passthrough_exprs ),
384
404
)
385
405
)
386
406
407
def _cross_join_w_labels(
    self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
) -> ArrayValue:
    """
    Cross join self with the labels array.

    Each row in self becomes N rows, one for each label in labels_array.

    Args:
        labels_array: ArrayValue holding the labels to pair with every row.
        join_side: Which side of the join self is placed on. Any value
            other than "left" is handled as "right".

    Returns:
        ArrayValue: The result of the cross join.
    """
    self_on_left = join_side == "left"
    self_side = join_def.JoinSide.LEFT if self_on_left else join_def.JoinSide.RIGHT
    # The labels are assigned whatever side `inverse()` yields for self's side.
    label_side = self_side.inverse()

    # Every column maps to its own id, so column ids are unchanged by the join.
    label_columns = tuple(
        join_def.JoinColumnMapping(label_side, col_id, col_id)
        for col_id in labels_array.schema.names
    )
    self_columns = tuple(
        join_def.JoinColumnMapping(self_side, col_id, col_id)
        for col_id in self.schema.names
    )
    cross_join = join_def.JoinDefinition(
        conditions=(), mappings=label_columns + self_columns, type="cross"
    )

    if self_on_left:
        return self.join(labels_array, join_def=cross_join)
    return labels_array.join(self, join_def=cross_join)
433
+
434
def _create_unpivot_labels_array(
    self,
    former_column_labels: typing.Sequence[typing.Hashable],
    col_ids: typing.Sequence[str],
) -> ArrayValue:
    """Create an ArrayValue with one row per former column label.

    Args:
        former_column_labels: Labels to turn into rows. A label that is not
            already a tuple is treated as a one-element tuple.
        col_ids: Column ids of the result; the i-th id receives the i-th
            element of each label tuple.

    Returns:
        ArrayValue: Built from a pyarrow table of the label rows, using
        this object's session.
    """
    # Normalize every label to tuple form so it can be indexed by position.
    label_tuples = (
        label if isinstance(label, tuple) else (label,)
        for label in former_column_labels
    )
    label_rows = [
        {col_id: parts[position] for position, col_id in enumerate(col_ids)}
        for parts in label_tuples
    ]
    label_table = pa.Table.from_pylist(label_rows)
    return ArrayValue.from_pyarrow(label_table, session=self.session)
448
+
387
449
def join (
388
450
self ,
389
451
other : ArrayValue ,
0 commit comments