@@ -995,7 +995,7 @@ def filter(self, predicate: scalars.Expression):
995
995
996
996
def aggregate_all_and_stack (
997
997
self ,
998
- operation : agg_ops .UnaryAggregateOp ,
998
+ operation : typing . Union [ agg_ops .UnaryAggregateOp , agg_ops . NullaryAggregateOp ] ,
999
999
* ,
1000
1000
axis : int | str = 0 ,
1001
1001
value_col_id : str = "values" ,
@@ -1004,7 +1004,12 @@ def aggregate_all_and_stack(
1004
1004
axis_n = utils .get_axis_number (axis )
1005
1005
if axis_n == 0 :
1006
1006
aggregations = [
1007
- (ex .UnaryAggregation (operation , ex .free_var (col_id )), col_id )
1007
+ (
1008
+ ex .UnaryAggregation (operation , ex .free_var (col_id ))
1009
+ if isinstance (operation , agg_ops .UnaryAggregateOp )
1010
+ else ex .NullaryAggregation (operation ),
1011
+ col_id ,
1012
+ )
1008
1013
for col_id in self .value_columns
1009
1014
]
1010
1015
index_id = guid .generate_guid ()
@@ -1033,6 +1038,11 @@ def aggregate_all_and_stack(
1033
1038
(ex .UnaryAggregation (agg_ops .AnyValueOp (), ex .free_var (col_id )), col_id )
1034
1039
for col_id in [* self .index_columns ]
1035
1040
]
1041
+ # TODO: may need to add NullaryAggregation to main_aggregation
1042
+ # when agg adds support for axis=1; this is needed for agg("size", axis=1)
1043
+ assert isinstance (
1044
+ operation , agg_ops .UnaryAggregateOp
1045
+ ), f"Expected a unary operation, but got { operation } . Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
1036
1046
main_aggregation = (
1037
1047
ex .UnaryAggregation (operation , ex .free_var (value_col_id )),
1038
1048
value_col_id ,
@@ -1125,7 +1135,11 @@ def remap_f(x):
1125
1135
def aggregate (
1126
1136
self ,
1127
1137
by_column_ids : typing .Sequence [str ] = (),
1128
- aggregations : typing .Sequence [typing .Tuple [str , agg_ops .UnaryAggregateOp ]] = (),
1138
+ aggregations : typing .Sequence [
1139
+ typing .Tuple [
1140
+ str , typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1141
+ ]
1142
+ ] = (),
1129
1143
* ,
1130
1144
dropna : bool = True ,
1131
1145
) -> typing .Tuple [Block , typing .Sequence [str ]]:
@@ -1139,7 +1153,9 @@ def aggregate(
1139
1153
"""
1140
1154
agg_specs = [
1141
1155
(
1142
- ex .UnaryAggregation (operation , ex .free_var (input_id )),
1156
+ ex .UnaryAggregation (operation , ex .free_var (input_id ))
1157
+ if isinstance (operation , agg_ops .UnaryAggregateOp )
1158
+ else ex .NullaryAggregation (operation ),
1143
1159
guid .generate_guid (),
1144
1160
)
1145
1161
for input_id , operation in aggregations
@@ -1175,18 +1191,32 @@ def aggregate(
1175
1191
output_col_ids ,
1176
1192
)
1177
1193
1178
- def get_stat (self , column_id : str , stat : agg_ops .UnaryAggregateOp ):
1194
+ def get_stat (
1195
+ self ,
1196
+ column_id : str ,
1197
+ stat : typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ],
1198
+ ):
1179
1199
"""Gets aggregates immediately, and caches it"""
1180
1200
if stat .name in self ._stats_cache [column_id ]:
1181
1201
return self ._stats_cache [column_id ][stat .name ]
1182
1202
1183
1203
# TODO: Convert nonstandard stats into standard stats where possible (popvar, etc.)
1184
1204
# if getting a standard stat, just go get the rest of them
1185
- standard_stats = self ._standard_stats (column_id )
1205
+ standard_stats = typing .cast (
1206
+ typing .Sequence [
1207
+ typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1208
+ ],
1209
+ self ._standard_stats (column_id ),
1210
+ )
1186
1211
stats_to_fetch = standard_stats if stat in standard_stats else [stat ]
1187
1212
1188
1213
aggregations = [
1189
- (ex .UnaryAggregation (stat , ex .free_var (column_id )), stat .name )
1214
+ (
1215
+ ex .UnaryAggregation (stat , ex .free_var (column_id ))
1216
+ if isinstance (stat , agg_ops .UnaryAggregateOp )
1217
+ else ex .NullaryAggregation (stat ),
1218
+ stat .name ,
1219
+ )
1190
1220
for stat in stats_to_fetch
1191
1221
]
1192
1222
expr = self .expr .aggregate (aggregations )
@@ -1231,13 +1261,20 @@ def get_binary_stat(
1231
1261
def summarize (
1232
1262
self ,
1233
1263
column_ids : typing .Sequence [str ],
1234
- stats : typing .Sequence [agg_ops .UnaryAggregateOp ],
1264
+ stats : typing .Sequence [
1265
+ typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1266
+ ],
1235
1267
):
1236
1268
"""Get a list of stats as a deferred block object."""
1237
1269
label_col_id = guid .generate_guid ()
1238
1270
labels = [stat .name for stat in stats ]
1239
1271
aggregations = [
1240
- (ex .UnaryAggregation (stat , ex .free_var (col_id )), f"{ col_id } -{ stat .name } " )
1272
+ (
1273
+ ex .UnaryAggregation (stat , ex .free_var (col_id ))
1274
+ if isinstance (stat , agg_ops .UnaryAggregateOp )
1275
+ else ex .NullaryAggregation (stat ),
1276
+ f"{ col_id } -{ stat .name } " ,
1277
+ )
1241
1278
for stat in stats
1242
1279
for col_id in column_ids
1243
1280
]
0 commit comments