Skip to content

Commit 0145656

Browse files
perf: Simplify sum aggregate SQL text (#1395)
1 parent 7990262 commit 0145656

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

bigframes/core/compile/aggregate_compiler.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ def _(
164164
) -> ibis_types.NumericValue:
165165
# Will be null if all inputs are null. Pandas defaults to zero sum though.
166166
bq_sum = _apply_window_if_present(column.sum(), window)
167-
return (
168-
ibis_api.case().when(bq_sum.isnull(), ibis_types.literal(0)).else_(bq_sum).end() # type: ignore
169-
)
167+
return bq_sum.fillna(ibis_types.literal(0))
170168

171169

172170
@compile_unary_agg.register

bigframes/core/compile/compiled.py

-5
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ def aggregate(
205205
self,
206206
aggregations: typing.Sequence[tuple[ex.Aggregation, str]],
207207
by_column_ids: typing.Sequence[ex.DerefOp] = (),
208-
dropna: bool = True,
209208
order_by: typing.Sequence[OrderingExpression] = (),
210209
) -> UnorderedIR:
211210
"""
@@ -230,10 +229,6 @@ def aggregate(
230229
for aggregate, col_out in aggregations
231230
}
232231
if by_column_ids:
233-
if dropna:
234-
table = table.filter(
235-
[table[ref.id.sql].notnull() for ref in by_column_ids]
236-
)
237232
result = table.group_by((ref.id.sql for ref in by_column_ids)).aggregate(
238233
**stats
239234
)

bigframes/core/compile/compiler.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import google.cloud.bigquery
2626
import pandas as pd
2727

28-
from bigframes import dtypes
28+
from bigframes import dtypes, operations
2929
from bigframes.core import utils
3030
import bigframes.core.compile.compiled as compiled
3131
import bigframes.core.compile.concat as concat_impl
@@ -278,8 +278,13 @@ def compile_rowcount(self, node: nodes.RowCountNode):
278278
def compile_aggregate(self, node: nodes.AggregateNode):
279279
aggs = tuple((agg, id.sql) for agg, id in node.aggregations)
280280
result = self.compile_node(node.child).aggregate(
281-
aggs, node.by_column_ids, node.dropna, order_by=node.order_by
281+
aggs, node.by_column_ids, order_by=node.order_by
282282
)
283+
# TODO: Remove dropna field and use filter node instead
284+
if node.dropna:
285+
for key in node.by_column_ids:
286+
if node.child.field_by_id[key.id].nullable:
287+
result = result.filter(operations.notnull_op.as_expr(key))
283288
return result
284289

285290
@_compile_node.register

0 commit comments

Comments
 (0)