Skip to content

Commit c598c0a

Browse files
perf: Speed up DataFrame corr, cov (#1309)
1 parent 6785aee commit c598c0a

File tree

4 files changed

+56
-5
lines changed

4 files changed

+56
-5
lines changed

bigframes/core/utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def is_dict_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Mapping]
5252

5353

5454
def combine_indices(index1: pd.Index, index2: pd.Index) -> pd.MultiIndex:
55-
"""Combines indices into multi-index while preserving dtypes, names."""
55+
"""Combines indices into multi-index while preserving dtypes, names merging by rows 1:1"""
5656
multi_index = pd.MultiIndex.from_frame(
5757
pd.concat([index1.to_frame(index=False), index2.to_frame(index=False)], axis=1)
5858
)
@@ -61,6 +61,20 @@ def combine_indices(index1: pd.Index, index2: pd.Index) -> pd.MultiIndex:
6161
return multi_index
6262

6363

64+
def cross_indices(index1: pd.Index, index2: pd.Index) -> pd.MultiIndex:
65+
"""Combines indices into multi-index while preserving dtypes, names using cross product"""
66+
multi_index = pd.MultiIndex.from_frame(
67+
pd.merge(
68+
left=index1.to_frame(index=False),
69+
right=index2.to_frame(index=False),
70+
how="cross",
71+
)
72+
)
73+
# to_frame will produce numbered default names, we don't want these
74+
multi_index.names = [*index1.names, *index2.names]
75+
return multi_index
76+
77+
6478
def index_as_tuples(index: pd.Index) -> typing.Sequence[typing.Tuple]:
6579
if isinstance(index, pd.MultiIndex):
6680
return [label for label in index]

bigframes/dataframe.py

+37
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,35 @@ def combine(
12701270
def combine_first(self, other: DataFrame):
12711271
return self._apply_dataframe_binop(other, ops.fillna_op)
12721272

1273+
def _fast_stat_matrix(self, op: agg_ops.BinaryAggregateOp) -> DataFrame:
1274+
"""Faster corr, cov calculations, but creates more sql text, so cannot scale to many columns"""
1275+
assert len(self.columns) * len(self.columns) < bigframes.constants.MAX_COLUMNS
1276+
orig_columns = self.columns
1277+
frame = self.copy()
1278+
# Replace column names with 0 to n - 1 to keep order
1279+
# and avoid the influence of duplicated column name
1280+
frame.columns = pandas.Index(range(len(orig_columns)))
1281+
frame = frame.astype(bigframes.dtypes.FLOAT_DTYPE)
1282+
block = frame._block
1283+
1284+
aggregations = [
1285+
ex.BinaryAggregation(op, ex.deref(left_col), ex.deref(right_col))
1286+
for left_col in block.value_columns
1287+
for right_col in block.value_columns
1288+
]
1289+
# unique columns stops
1290+
uniq_orig_columns = utils.combine_indices(
1291+
orig_columns, pandas.Index(range(len(orig_columns)))
1292+
)
1293+
labels = utils.cross_indices(uniq_orig_columns, uniq_orig_columns)
1294+
1295+
block, _ = block.aggregate(aggregations=aggregations, column_labels=labels)
1296+
1297+
block = block.stack(levels=orig_columns.nlevels + 1)
1298+
# The aggregate operation crated a index level with just 0, need to drop it
1299+
# Also, drop the last level of each index, which was created to guarantee uniqueness
1300+
return DataFrame(block).droplevel(0).droplevel(-1, axis=0).droplevel(-1, axis=1)
1301+
12731302
def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFrame:
12741303
if method != "pearson":
12751304
raise NotImplementedError(
@@ -1285,6 +1314,10 @@ def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFr
12851314
else:
12861315
frame = self._drop_non_numeric()
12871316

1317+
if len(frame.columns) <= 30:
1318+
return frame._fast_stat_matrix(agg_ops.CorrOp())
1319+
1320+
frame = frame.copy()
12881321
orig_columns = frame.columns
12891322
# Replace column names with 0 to n - 1 to keep order
12901323
# and avoid the influence of duplicated column name
@@ -1393,6 +1426,10 @@ def cov(self, *, numeric_only: bool = False) -> DataFrame:
13931426
else:
13941427
frame = self._drop_non_numeric()
13951428

1429+
if len(frame.columns) <= 30:
1430+
return frame._fast_stat_matrix(agg_ops.CovOp())
1431+
1432+
frame = frame.copy()
13961433
orig_columns = frame.columns
13971434
# Replace column names with 0 to n - 1 to keep order
13981435
# and avoid the influence of duplicated column name

tests/system/large/test_dataframe.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See: https://github.com/python/cpython/issues/112282
1010
reason="setrecursionlimit has no effect on the Python C stack since Python 3.12.",
1111
)
12-
def test_corr_w_numeric_only(scalars_df_numeric_150_columns_maybe_ordered):
12+
def test_corr_150_columns(scalars_df_numeric_150_columns_maybe_ordered):
1313
scalars_df, scalars_pandas_df = scalars_df_numeric_150_columns_maybe_ordered
1414
bf_result = scalars_df.corr(numeric_only=True).to_pandas()
1515
pd_result = scalars_pandas_df.corr(numeric_only=True)
@@ -28,7 +28,7 @@ def test_corr_w_numeric_only(scalars_df_numeric_150_columns_maybe_ordered):
2828
# See: https://github.com/python/cpython/issues/112282
2929
reason="setrecursionlimit has no effect on the Python C stack since Python 3.12.",
3030
)
31-
def test_cov_w_numeric_only(scalars_df_numeric_150_columns_maybe_ordered):
31+
def test_cov_150_columns(scalars_df_numeric_150_columns_maybe_ordered):
3232
scalars_df, scalars_pandas_df = scalars_df_numeric_150_columns_maybe_ordered
3333
bf_result = scalars_df.cov(numeric_only=True).to_pandas()
3434
pd_result = scalars_pandas_df.cov(numeric_only=True)

tests/system/small/test_dataframe.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2214,7 +2214,7 @@ def test_combine_first(
22142214
),
22152215
],
22162216
)
2217-
def test_corr_w_numeric_only(scalars_dfs_maybe_ordered, columns, numeric_only):
2217+
def test_df_corr_w_numeric_only(scalars_dfs_maybe_ordered, columns, numeric_only):
22182218
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered
22192219

22202220
bf_result = scalars_df[columns].corr(numeric_only=numeric_only).to_pandas()
@@ -2228,7 +2228,7 @@ def test_corr_w_numeric_only(scalars_dfs_maybe_ordered, columns, numeric_only):
22282228
)
22292229

22302230

2231-
def test_corr_w_invalid_parameters(scalars_dfs):
2231+
def test_df_corr_w_invalid_parameters(scalars_dfs):
22322232
columns = ["int64_too", "int64_col", "float64_col"]
22332233
scalars_df, _ = scalars_dfs
22342234

0 commit comments

Comments
 (0)