@@ -1270,6 +1270,35 @@ def combine(
1270
1270
def combine_first(self, other: DataFrame):
    """Combine this frame with ``other`` using the fill-na binary operation.

    Delegates to ``_apply_dataframe_binop`` with ``ops.fillna_op``.
    NOTE(review): presumably mirrors pandas ``DataFrame.combine_first``
    (values missing here are taken from ``other``) -- confirm against
    ``ops.fillna_op``.

    Args:
        other: The DataFrame supplying the right-hand operand.

    Returns:
        Whatever ``_apply_dataframe_binop`` returns for the two frames.
    """
    fill_missing = ops.fillna_op
    combined = self._apply_dataframe_binop(other, fill_missing)
    return combined
1272
1272
1273
def _fast_stat_matrix(self, op: agg_ops.BinaryAggregateOp) -> DataFrame:
    """Compute a pairwise statistic matrix (e.g. corr, cov) in one aggregation.

    One ``BinaryAggregation`` is emitted for every ordered pair of value
    columns. This is the faster path, but it generates more SQL text, so it
    cannot scale to many columns.

    Args:
        op: The binary aggregate operation applied to each column pair.

    Returns:
        A DataFrame built from the stacked aggregation result, with the
        helper index/column levels dropped.

    Raises:
        AssertionError: If the squared column count is not below
            ``bigframes.constants.MAX_COLUMNS``.
    """
    source_labels = self.columns
    width = len(source_labels)
    # The number of aggregations is the square of the column count.
    assert width * width < bigframes.constants.MAX_COLUMNS

    # Rename the columns to 0..n-1 so that ordering is preserved and
    # duplicated column names cannot interfere with the computation.
    working = self.copy()
    working.columns = pandas.Index(range(width))
    working = working.astype(bigframes.dtypes.FLOAT_DTYPE)
    block = working._block

    # One aggregation per ordered (left, right) pair of value columns.
    pair_aggregations = []
    for lhs in block.value_columns:
        for rhs in block.value_columns:
            pair_aggregations.append(
                ex.BinaryAggregation(op, ex.deref(lhs), ex.deref(rhs))
            )

    # Pair each original label with its position; the extra positional level
    # exists to guarantee uniqueness and is dropped again before returning.
    positions = pandas.Index(range(width))
    unique_labels = utils.combine_indices(source_labels, positions)
    pair_labels = utils.cross_indices(unique_labels, unique_labels)

    aggregated, _ = block.aggregate(
        aggregations=pair_aggregations, column_labels=pair_labels
    )
    stacked = aggregated.stack(levels=source_labels.nlevels + 1)

    result = DataFrame(stacked)
    # The aggregate operation created an index level holding just 0; drop it.
    result = result.droplevel(0)
    # Drop the last level on each axis, which was added only for uniqueness.
    result = result.droplevel(-1, axis=0)
    return result.droplevel(-1, axis=1)
1301
+
1273
1302
def corr (self , method = "pearson" , min_periods = None , numeric_only = False ) -> DataFrame :
1274
1303
if method != "pearson" :
1275
1304
raise NotImplementedError (
@@ -1285,6 +1314,10 @@ def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFr
1285
1314
else :
1286
1315
frame = self ._drop_non_numeric ()
1287
1316
1317
+ if len (frame .columns ) <= 30 :
1318
+ return frame ._fast_stat_matrix (agg_ops .CorrOp ())
1319
+
1320
+ frame = frame .copy ()
1288
1321
orig_columns = frame .columns
1289
1322
# Replace column names with 0 to n - 1 to keep order
1290
1323
# and avoid the influence of duplicated column name
@@ -1393,6 +1426,10 @@ def cov(self, *, numeric_only: bool = False) -> DataFrame:
1393
1426
else :
1394
1427
frame = self ._drop_non_numeric ()
1395
1428
1429
+ if len (frame .columns ) <= 30 :
1430
+ return frame ._fast_stat_matrix (agg_ops .CovOp ())
1431
+
1432
+ frame = frame .copy ()
1396
1433
orig_columns = frame .columns
1397
1434
# Replace column names with 0 to n - 1 to keep order
1398
1435
# and avoid the influence of duplicated column name
0 commit comments