Skip to content

Commit 8f9ece6

Browse files
fix: infer narrowest numeric type when combining numeric columns (#602)
Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 9f8f181 commit 8f9ece6

File tree

14 files changed

+233
-507
lines changed

14 files changed

+233
-507
lines changed

bigframes/core/__init__.py

+75-13
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,7 @@ def unpivot(
354354
*,
355355
passthrough_columns: typing.Sequence[str] = (),
356356
index_col_ids: typing.Sequence[str] = ["index"],
357-
dtype: typing.Union[
358-
bigframes.dtypes.Dtype, typing.Tuple[bigframes.dtypes.Dtype, ...]
359-
] = pandas.Float64Dtype(),
360-
how: typing.Literal["left", "right"] = "left",
357+
join_side: typing.Literal["left", "right"] = "left",
361358
) -> ArrayValue:
362359
"""
363360
Unpivot ArrayValue columns.
@@ -367,23 +364,88 @@ def unpivot(
367364
unpivot_columns: Mapping of column id to list of input column ids. Lists of input columns may use None.
368365
passthrough_columns: Columns that will not be unpivoted. Column id will be preserved.
369366
index_col_ids (Sequence[str]): The column ids to be used for the row labels.
370-
dtype (dtype or list of dtype): Dtype to use for the unpivot columns. If list, must be equal in number to unpivot_columns.
371367
372368
Returns:
373369
ArrayValue: The unpivoted ArrayValue
374370
"""
371+
# There will be N labels, used to disambiguate which of N source columns produced each output row
372+
explode_offsets_id = bigframes.core.guid.generate_guid("unpivot_offsets_")
373+
labels_array = self._create_unpivot_labels_array(row_labels, index_col_ids)
374+
labels_array = labels_array.promote_offsets(explode_offsets_id)
375+
376+
# Unpivot creates N output rows for each input row, labels disambiguate these N rows
377+
joined_array = self._cross_join_w_labels(labels_array, join_side)
378+
379+
# Build the output rows as a case statement that selects between the N input columns
380+
unpivot_exprs = []
381+
# Supports producing multiple stacked output columns for stacking only part of hierarchical index
382+
for col_id, input_ids in unpivot_columns:
383+
# row explode offset used to choose the input column
384+
# we use offset instead of label as labels are not necessarily unique
385+
cases = tuple(
386+
(
387+
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
388+
ex.free_var(id_or_null)
389+
if (id_or_null is not None)
390+
else ex.const(None),
391+
)
392+
for i, id_or_null in enumerate(input_ids)
393+
)
394+
col_expr = ops.case_when_op.as_expr(*cases)
395+
unpivot_exprs.append((col_expr, col_id))
396+
397+
label_exprs = ((ex.free_var(id), id) for id in index_col_ids)
398+
# passthrough columns are unchanged, just repeated N times each
399+
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
375400
return ArrayValue(
376-
nodes.UnpivotNode(
377-
child=self.node,
378-
row_labels=tuple(row_labels),
379-
unpivot_columns=tuple(unpivot_columns),
380-
passthrough_columns=tuple(passthrough_columns),
381-
index_col_ids=tuple(index_col_ids),
382-
dtype=dtype,
383-
how=how,
401+
nodes.ProjectionNode(
402+
child=joined_array.node,
403+
assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs),
384404
)
385405
)
386406

407+
def _cross_join_w_labels(
408+
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]
409+
) -> ArrayValue:
410+
"""
411+
Convert each row in self to N rows, one for each label in labels array.
412+
"""
413+
table_join_side = (
414+
join_def.JoinSide.LEFT if join_side == "left" else join_def.JoinSide.RIGHT
415+
)
416+
labels_join_side = table_join_side.inverse()
417+
labels_mappings = tuple(
418+
join_def.JoinColumnMapping(labels_join_side, id, id)
419+
for id in labels_array.schema.names
420+
)
421+
table_mappings = tuple(
422+
join_def.JoinColumnMapping(table_join_side, id, id)
423+
for id in self.schema.names
424+
)
425+
join = join_def.JoinDefinition(
426+
conditions=(), mappings=(*labels_mappings, *table_mappings), type="cross"
427+
)
428+
if join_side == "left":
429+
joined_array = self.join(labels_array, join_def=join)
430+
else:
431+
joined_array = labels_array.join(self, join_def=join)
432+
return joined_array
433+
434+
def _create_unpivot_labels_array(
435+
self,
436+
former_column_labels: typing.Sequence[typing.Hashable],
437+
col_ids: typing.Sequence[str],
438+
) -> ArrayValue:
439+
"""Create an ArrayValue from a list of label tuples."""
440+
rows = []
441+
for row_offset in range(len(former_column_labels)):
442+
row_label = former_column_labels[row_offset]
443+
row_label = (row_label,) if not isinstance(row_label, tuple) else row_label
444+
row = {col_ids[i]: row_label[i] for i in range(len(col_ids))}
445+
rows.append(row)
446+
447+
return ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=self.session)
448+
387449
def join(
388450
self,
389451
other: ArrayValue,

bigframes/core/block_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -857,5 +857,5 @@ def _idx_extrema(
857857
# Stack the entire column axis to produce single-column result
858858
# Assumption: uniform dtype for stackability
859859
return block.aggregate_all_and_stack(
860-
agg_ops.AnyValueOp(), dtype=block.dtypes[0]
860+
agg_ops.AnyValueOp(),
861861
).with_column_labels([original_block.index.name])

bigframes/core/blocks.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -914,9 +914,6 @@ def aggregate_all_and_stack(
914914
axis: int | str = 0,
915915
value_col_id: str = "values",
916916
dropna: bool = True,
917-
dtype: typing.Union[
918-
bigframes.dtypes.Dtype, typing.Tuple[bigframes.dtypes.Dtype, ...]
919-
] = pd.Float64Dtype(),
920917
) -> Block:
921918
axis_n = utils.get_axis_number(axis)
922919
if axis_n == 0:
@@ -931,7 +928,6 @@ def aggregate_all_and_stack(
931928
row_labels=self.column_labels.to_list(),
932929
index_col_ids=index_col_ids,
933930
unpivot_columns=tuple([(value_col_id, tuple(self.value_columns))]),
934-
dtype=dtype,
935931
)
936932
return Block(
937933
result_expr,
@@ -949,7 +945,6 @@ def aggregate_all_and_stack(
949945
index_col_ids=[guid.generate_guid()],
950946
unpivot_columns=[(value_col_id, tuple(self.value_columns))],
951947
passthrough_columns=[*self.index_columns, offset_col],
952-
dtype=dtype,
953948
)
954949
index_aggregations = [
955950
(ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
@@ -1512,22 +1507,18 @@ def stack(self, how="left", levels: int = 1):
15121507

15131508
# Get matching columns
15141509
unpivot_columns: List[Tuple[str, List[str]]] = []
1515-
dtypes = []
15161510
for val in result_col_labels:
15171511
col_id = guid.generate_guid("unpivot_")
15181512
input_columns, dtype = self._create_stack_column(val, row_label_tuples)
15191513
unpivot_columns.append((col_id, input_columns))
1520-
if dtype:
1521-
dtypes.append(dtype or pd.Float64Dtype())
15221514

15231515
added_index_columns = [guid.generate_guid() for _ in range(row_labels.nlevels)]
15241516
unpivot_expr = self._expr.unpivot(
15251517
row_labels=row_label_tuples,
15261518
passthrough_columns=self.index_columns,
15271519
unpivot_columns=unpivot_columns,
15281520
index_col_ids=added_index_columns,
1529-
dtype=tuple(dtypes),
1530-
how=how,
1521+
join_side=how,
15311522
)
15321523
new_index_level_names = self.column_labels.names[-levels:]
15331524
if how == "left":
@@ -1559,15 +1550,12 @@ def melt(
15591550
value_labels = [self.col_id_to_label[col_id] for col_id in value_vars]
15601551
id_labels = [self.col_id_to_label[col_id] for col_id in id_vars]
15611552

1562-
dtype = self._expr.get_column_type(value_vars[0])
1563-
15641553
unpivot_expr = self._expr.unpivot(
15651554
row_labels=value_labels,
15661555
passthrough_columns=id_vars,
15671556
unpivot_columns=(unpivot_col,),
15681557
index_col_ids=var_col_ids,
1569-
dtype=dtype,
1570-
how="right",
1558+
join_side="right",
15711559
)
15721560
index_id = guid.generate_guid()
15731561
unpivot_expr = unpivot_expr.promote_offsets(index_id)

0 commit comments

Comments
 (0)