Skip to content

Commit 4a84714

Browse files
feat: add df.unstack (#63)
1 parent d8910d4 commit 4a84714

File tree

7 files changed

+241
-122
lines changed

7 files changed

+241
-122
lines changed

bigframes/core/__init__.py

+79-34
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,11 @@ def unpivot(
963963
],
964964
*,
965965
passthrough_columns: typing.Sequence[str] = (),
966-
index_col_id: str = "index",
966+
index_col_ids: typing.Sequence[str] = ["index"],
967967
dtype: typing.Union[
968968
bigframes.dtypes.Dtype, typing.Sequence[bigframes.dtypes.Dtype]
969969
] = pandas.Float64Dtype(),
970+
how="left",
970971
) -> ArrayValue:
971972
"""
972973
Unpivot ArrayValue columns.
@@ -981,8 +982,11 @@ def unpivot(
981982
Returns:
982983
ArrayValue: The unpivoted ArrayValue
983984
"""
984-
table = self._to_ibis_expr(ordering_mode="offset_col")
985+
if how not in ("left", "right"):
986+
raise ValueError("'how' must be 'left' or 'right'")
987+
table = self._to_ibis_expr(ordering_mode="unordered", expose_hidden_cols=True)
985988
row_n = len(row_labels)
989+
hidden_col_ids = self._hidden_ordering_column_names.keys()
986990
if not all(
987991
len(source_columns) == row_n for _, source_columns in unpivot_columns
988992
):
@@ -992,33 +996,44 @@ def unpivot(
992996
unpivot_table = table.cross_join(
993997
ibis.memtable({unpivot_offset_id: range(row_n)})
994998
)
995-
unpivot_offsets_value = (
996-
(
997-
(unpivot_table[ORDER_ID_COLUMN] * row_n)
998-
+ unpivot_table[unpivot_offset_id]
999-
)
1000-
.cast(ibis_dtypes.int64)
1001-
.name(ORDER_ID_COLUMN),
1002-
)
1003-
1004999
# Use ibis memtable to infer type of rowlabels (if possible)
10051000
# TODO: Allow caller to specify dtype
1006-
labels_ibis_type = ibis.memtable({"col": row_labels})["col"].type()
1007-
labels_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(labels_ibis_type)
1008-
cases = [
1009-
(
1010-
i,
1011-
bigframes.dtypes.literal_to_ibis_scalar(
1012-
row_labels[i], force_dtype=labels_dtype # type:ignore
1013-
),
1014-
)
1015-
for i in range(len(row_labels))
1001+
if isinstance(row_labels[0], tuple):
1002+
labels_table = ibis.memtable(row_labels)
1003+
labels_ibis_types = [
1004+
labels_table[col].type() for col in labels_table.columns
1005+
]
1006+
else:
1007+
labels_ibis_types = [ibis.memtable({"col": row_labels})["col"].type()]
1008+
labels_dtypes = [
1009+
bigframes.dtypes.ibis_dtype_to_bigframes_dtype(ibis_type)
1010+
for ibis_type in labels_ibis_types
10161011
]
1017-
labels_value = (
1018-
typing.cast(ibis_types.IntegerColumn, unpivot_table[unpivot_offset_id])
1019-
.cases(cases, default=None) # type:ignore
1020-
.name(index_col_id)
1021-
)
1012+
1013+
label_columns = []
1014+
for label_part, (col_id, label_dtype) in enumerate(
1015+
zip(index_col_ids, labels_dtypes)
1016+
):
1017+
# Treat every label as a tuple, even if it was originally a scalar, so the same logic applies to multi-column labels
1018+
labels_as_tuples = [
1019+
label if isinstance(label, tuple) else (label,) for label in row_labels
1020+
]
1021+
cases = [
1022+
(
1023+
i,
1024+
bigframes.dtypes.literal_to_ibis_scalar(
1025+
label_tuple[label_part], # type:ignore
1026+
force_dtype=label_dtype, # type:ignore
1027+
),
1028+
)
1029+
for i, label_tuple in enumerate(labels_as_tuples)
1030+
]
1031+
labels_value = (
1032+
typing.cast(ibis_types.IntegerColumn, unpivot_table[unpivot_offset_id])
1033+
.cases(cases, default=None) # type:ignore
1034+
.name(col_id)
1035+
)
1036+
label_columns.append(labels_value)
10221037

10231038
unpivot_values = []
10241039
for j in range(len(unpivot_columns)):
@@ -1042,23 +1057,53 @@ def unpivot(
10421057
unpivot_values.append(unpivot_value.name(result_col))
10431058

10441059
unpivot_table = unpivot_table.select(
1045-
passthrough_columns, labels_value, *unpivot_values, unpivot_offsets_value
1060+
passthrough_columns,
1061+
*label_columns,
1062+
*unpivot_values,
1063+
*hidden_col_ids,
1064+
unpivot_offset_id,
10461065
)
10471066

1067+
# Extend the original ordering using unpivot_offset_id
1068+
old_ordering = self._ordering
1069+
if how == "left":
1070+
new_ordering = ExpressionOrdering(
1071+
ordering_value_columns=[
1072+
*old_ordering.ordering_value_columns,
1073+
OrderingColumnReference(unpivot_offset_id),
1074+
],
1075+
total_ordering_columns=frozenset(
1076+
[*old_ordering.total_ordering_columns, unpivot_offset_id]
1077+
),
1078+
)
1079+
else: # how=="right"
1080+
new_ordering = ExpressionOrdering(
1081+
ordering_value_columns=[
1082+
OrderingColumnReference(unpivot_offset_id),
1083+
*old_ordering.ordering_value_columns,
1084+
],
1085+
total_ordering_columns=frozenset(
1086+
[*old_ordering.total_ordering_columns, unpivot_offset_id]
1087+
),
1088+
)
10481089
value_columns = [
10491090
unpivot_table[value_col_id] for value_col_id, _ in unpivot_columns
10501091
]
10511092
passthrough_values = [unpivot_table[col] for col in passthrough_columns]
1093+
hidden_ordering_columns = [
1094+
unpivot_table[unpivot_offset_id],
1095+
*[unpivot_table[hidden_col] for hidden_col in hidden_col_ids],
1096+
]
10521097
return ArrayValue(
10531098
session=self._session,
10541099
table=unpivot_table,
1055-
columns=[unpivot_table[index_col_id], *value_columns, *passthrough_values],
1056-
hidden_ordering_columns=[unpivot_table[ORDER_ID_COLUMN]],
1057-
ordering=ExpressionOrdering(
1058-
ordering_value_columns=[OrderingColumnReference(ORDER_ID_COLUMN)],
1059-
integer_encoding=IntegerEncoding(is_encoded=True, is_sequential=True),
1060-
total_ordering_columns=frozenset([ORDER_ID_COLUMN]),
1061-
),
1100+
columns=[
1101+
*[unpivot_table[col_id] for col_id in index_col_ids],
1102+
*value_columns,
1103+
*passthrough_values,
1104+
],
1105+
hidden_ordering_columns=hidden_ordering_columns,
1106+
ordering=new_ordering,
10621107
)
10631108

10641109
def assign(self, source_id: str, destination_id: str) -> ArrayValue:

bigframes/core/blocks.py

+55-88
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ def aggregate_all_and_stack(
838838
]
839839
result_expr = self.expr.aggregate(aggregations, dropna=dropna).unpivot(
840840
row_labels=self.column_labels.to_list(),
841-
index_col_id="index",
841+
index_col_ids=["index"],
842842
unpivot_columns=[(value_col_id, self.value_columns)],
843843
dtype=dtype,
844844
)
@@ -849,7 +849,7 @@ def aggregate_all_and_stack(
849849
expr_with_offsets, offset_col = self.expr.promote_offsets()
850850
stacked_expr = expr_with_offsets.unpivot(
851851
row_labels=self.column_labels.to_list(),
852-
index_col_id=guid.generate_guid(),
852+
index_col_ids=[guid.generate_guid()],
853853
unpivot_columns=[(value_col_id, self.value_columns)],
854854
passthrough_columns=[*self.index_columns, offset_col],
855855
dtype=dtype,
@@ -1041,7 +1041,7 @@ def summarize(
10411041
expr = self.expr.aggregate(aggregations).unpivot(
10421042
labels,
10431043
unpivot_columns=columns,
1044-
index_col_id=label_col_id,
1044+
index_col_ids=[label_col_id],
10451045
)
10461046
labels = self._get_labels_for_columns(column_ids)
10471047
return Block(expr, column_labels=labels, index_columns=[label_col_id])
@@ -1225,116 +1225,83 @@ def pivot(
12251225

12261226
return result_block.with_column_labels(column_index)
12271227

1228-
def stack(self):
1228+
def stack(self, how="left", dropna=True, sort=True, levels: int = 1):
12291229
"""Unpivot last column axis level into row axis"""
1230-
if isinstance(self.column_labels, pd.MultiIndex):
1231-
return self._stack_multi()
1232-
else:
1233-
return self._stack_mono()
1234-
1235-
def _stack_mono(self):
1236-
if isinstance(self.column_labels, pd.MultiIndex):
1237-
raise ValueError("Expected single level index")
1238-
12391230
# These are the values that will be turned into rows
1240-
stack_values = self.column_labels.drop_duplicates().sort_values()
12411231

1242-
# Get matching columns
1243-
unpivot_columns: List[Tuple[str, List[str]]] = []
1244-
dtypes: List[bigframes.dtypes.Dtype] = []
1245-
col_id = guid.generate_guid("unpivot_")
1246-
dtype = None
1247-
input_columns: Sequence[Optional[str]] = []
1248-
for uvalue in stack_values:
1249-
matching_ids = self.label_to_col_id.get(uvalue, [])
1250-
input_id = matching_ids[0] if len(matching_ids) > 0 else None
1251-
if input_id:
1252-
if dtype and dtype != self._column_type(input_id):
1253-
raise NotImplementedError(
1254-
"Cannot stack columns with non-matching dtypes."
1255-
)
1256-
else:
1257-
dtype = self._column_type(input_id)
1258-
input_columns.append(input_id)
1259-
unpivot_columns.append((col_id, input_columns))
1260-
if dtype:
1261-
dtypes.append(dtype or pd.Float64Dtype())
1232+
col_labels, row_labels = utils.split_index(self.column_labels, levels=levels)
1233+
if dropna:
1234+
row_labels = row_labels.drop_duplicates()
1235+
if sort:
1236+
row_labels = row_labels.sort_values()
12621237

1263-
added_index_column = col_id = guid.generate_guid()
1264-
unpivot_expr = self._expr.unpivot(
1265-
row_labels=stack_values,
1266-
passthrough_columns=self.index_columns,
1267-
unpivot_columns=unpivot_columns,
1268-
index_col_id=added_index_column,
1269-
dtype=dtypes,
1270-
)
1271-
block = Block(
1272-
unpivot_expr,
1273-
index_columns=[*self.index_columns, added_index_column],
1274-
column_labels=[None],
1275-
index_labels=[*self._index_labels, self.column_labels.names[-1]],
1276-
)
1277-
return block
1278-
1279-
def _stack_multi(self):
1280-
if not isinstance(self.column_labels, pd.MultiIndex):
1281-
raise ValueError("Expected multi-index")
1282-
1283-
# These are the values that will be turned into rows
1284-
stack_values = (
1285-
self.column_labels.get_level_values(-1).drop_duplicates().sort_values()
1286-
)
1238+
row_label_tuples = utils.index_as_tuples(row_labels)
12871239

1288-
result_col_labels = (
1289-
self.column_labels.droplevel(-1)
1290-
.drop_duplicates()
1291-
.sort_values()
1292-
.dropna(how="all")
1293-
)
1240+
if col_labels is not None:
1241+
result_index = col_labels.drop_duplicates().sort_values().dropna(how="all")
1242+
result_col_labels = utils.index_as_tuples(result_index)
1243+
else:
1244+
result_index = pd.Index([None])
1245+
result_col_labels = list([()])
12941246

12951247
# Get matching columns
12961248
unpivot_columns: List[Tuple[str, List[str]]] = []
12971249
dtypes = []
12981250
for val in result_col_labels:
12991251
col_id = guid.generate_guid("unpivot_")
1300-
dtype = None
1301-
input_columns: Sequence[Optional[str]] = []
1302-
for uvalue in stack_values:
1303-
# Need to unpack if still a multi-index after dropping 1 level
1304-
label_to_match = (
1305-
(val, uvalue) if result_col_labels.nlevels == 1 else (*val, uvalue)
1306-
)
1307-
matching_ids = self.label_to_col_id.get(label_to_match, [])
1308-
input_id = matching_ids[0] if len(matching_ids) > 0 else None
1309-
if input_id:
1310-
if dtype and dtype != self._column_type(input_id):
1311-
raise NotImplementedError(
1312-
"Cannot stack columns with non-matching dtypes."
1313-
)
1314-
else:
1315-
dtype = self._column_type(input_id)
1316-
input_columns.append(input_id)
1317-
# Input column i is the first one that
1252+
input_columns, dtype = self._create_stack_column(val, row_label_tuples)
13181253
unpivot_columns.append((col_id, input_columns))
13191254
if dtype:
13201255
dtypes.append(dtype or pd.Float64Dtype())
13211256

1322-
added_index_column = col_id = guid.generate_guid()
1257+
added_index_columns = [guid.generate_guid() for _ in range(row_labels.nlevels)]
13231258
unpivot_expr = self._expr.unpivot(
1324-
row_labels=stack_values,
1259+
row_labels=row_label_tuples,
13251260
passthrough_columns=self.index_columns,
13261261
unpivot_columns=unpivot_columns,
1327-
index_col_id=added_index_column,
1262+
index_col_ids=added_index_columns,
13281263
dtype=dtypes,
1264+
how=how,
13291265
)
1266+
new_index_level_names = self.column_labels.names[-levels:]
1267+
if how == "left":
1268+
index_columns = [*self.index_columns, *added_index_columns]
1269+
index_labels = [*self._index_labels, *new_index_level_names]
1270+
else:
1271+
index_columns = [*added_index_columns, *self.index_columns]
1272+
index_labels = [*new_index_level_names, *self._index_labels]
1273+
13301274
block = Block(
13311275
unpivot_expr,
1332-
index_columns=[*self.index_columns, added_index_column],
1333-
column_labels=result_col_labels,
1334-
index_labels=[*self._index_labels, self.column_labels.names[-1]],
1276+
index_columns=index_columns,
1277+
column_labels=result_index,
1278+
index_labels=index_labels,
13351279
)
13361280
return block
13371281

1282+
def _create_stack_column(
    self, col_label: typing.Tuple, stack_labels: typing.Sequence[typing.Tuple]
):
    """Find the input column feeding each stacked row of one result column.

    For every entry of ``stack_labels`` the full column label is formed by
    concatenating ``col_label`` with that entry, and the first column id
    registered under that label in ``self.label_to_col_id`` is selected.

    Args:
        col_label: Label parts identifying the result column; may be an
            empty tuple.
        stack_labels: Label tuples, one per stacked row position.

    Returns:
        A ``(input_columns, dtype)`` pair. ``input_columns`` has one entry
        per element of ``stack_labels``: the matching column id, or ``None``
        when no column carries that label. ``dtype`` is the common dtype of
        the matched columns, or ``pd.Float64Dtype()`` when nothing matched.

    Raises:
        NotImplementedError: If the matched columns do not all share the
            same dtype.
    """
    dtype = None
    input_columns: typing.List[typing.Optional[str]] = []
    for stack_label in stack_labels:
        label_to_match = (*col_label, *stack_label)
        if len(label_to_match) == 1:
            # A one-part label is looked up by its bare value rather than as
            # a 1-tuple (presumably how single-level column labels are keyed
            # in label_to_col_id; verify against that mapping).
            label_to_match = label_to_match[0]
        matching_ids = self.label_to_col_id.get(label_to_match, [])
        # When several columns share the label, only the first one is used.
        input_id = matching_ids[0] if len(matching_ids) > 0 else None
        if input_id:
            column_dtype = self._column_type(input_id)
            if dtype and dtype != column_dtype:
                raise NotImplementedError(
                    "Cannot stack columns with non-matching dtypes."
                )
            dtype = column_dtype
        # Unmatched labels are recorded as None so that positions stay
        # aligned with stack_labels.
        input_columns.append(input_id)
    return input_columns, dtype or pd.Float64Dtype()
1304+
13381305
def _column_type(self, col_id: str) -> bigframes.dtypes.Dtype:
13391306
col_offset = self.value_columns.index(col_id)
13401307
dtype = self.dtypes[col_offset]

bigframes/core/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,26 @@ def combine_indices(index1: pd.Index, index2: pd.Index) -> pd.MultiIndex:
4949
return multi_index
5050

5151

52+
def index_as_tuples(index: pd.Index) -> typing.Sequence[typing.Tuple]:
    """Return the labels of ``index`` as a list of tuples.

    Iterating a ``pd.MultiIndex`` already yields one tuple per entry, so those
    entries are returned as-is. Each label of any other index is wrapped in a
    1-tuple so callers can treat both cases uniformly.

    Args:
        index: The index whose labels should be listed.

    Returns:
        A list with one tuple per entry of ``index``, in index order.
    """
    if isinstance(index, pd.MultiIndex):
        return list(index)
    return [(label,) for label in index]
57+
58+
59+
def split_index(
    index: pd.Index, levels: int = 1
) -> typing.Tuple[typing.Optional[pd.Index], pd.Index]:
    """Split ``index`` into its leading levels and its last ``levels`` levels.

    Args:
        index: The index (single-level or ``pd.MultiIndex``) to split.
        levels: How many trailing levels to split off. Must be at least 1.

    Returns:
        A ``(leading, trailing)`` pair. ``trailing`` holds the last ``levels``
        levels of ``index`` and ``leading`` holds the levels before them.
        When ``levels`` covers every level of ``index`` (or more), ``leading``
        is ``None`` and ``trailing`` is ``index`` itself.

    Raises:
        ValueError: If ``levels`` is less than 1.
    """
    if levels < 1:
        # Without this guard the call below would try to drop every level of
        # the index and fail inside pandas with a less helpful message.
        raise ValueError(f"'levels' must be at least 1, got {levels}")
    nlevels = index.nlevels
    remaining = nlevels - levels
    if remaining <= 0:
        return (None, index)
    leading = index.droplevel(list(range(remaining, nlevels)))
    trailing = index.droplevel(list(range(remaining)))
    return leading, trailing
70+
71+
5272
def get_standardized_ids(
5373
col_labels: Iterable[Hashable], idx_labels: Iterable[Hashable] = ()
5474
) -> tuple[list[str], list[str]]:

0 commit comments

Comments
 (0)