Skip to content

Commit 34d01b2

Browse files
authored
chore: support comparison, ordering, and filtering for timedeltas (#1387)
* [WIP] support timedelta ordering and filtering
* chore: support comparison, ordering, and filtering for timedeltas
* fix format
* some cleanups
* use operator package for testing
* fix test error
1 parent 44f4137 commit 34d01b2

File tree

3 files changed

+157
-32
lines changed

3 files changed

+157
-32
lines changed

bigframes/core/rewrite/timedeltas.py

+33 -22
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import dataclasses
1618
import functools
1719
import typing
@@ -27,6 +29,14 @@ class _TypedExpr:
2729
expr: ex.Expression
2830
dtype: dtypes.Dtype
2931

32+
@classmethod
33+
def create_op_expr(
34+
cls, op: typing.Union[ops.ScalarOp, ops.RowOp], *inputs: _TypedExpr
35+
) -> _TypedExpr:
36+
expr = op.as_expr(*tuple(x.expr for x in inputs)) # type: ignore
37+
dtype = op.output_type(*tuple(x.dtype for x in inputs))
38+
return cls(expr, dtype)
39+
3040

3141
def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
3242
"""
@@ -38,12 +48,27 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
3848
(_rewrite_expressions(expr, root.schema).expr, column_id)
3949
for expr, column_id in root.assignments
4050
)
41-
root = nodes.ProjectionNode(root.child, updated_assignments)
51+
return nodes.ProjectionNode(root.child, updated_assignments)
52+
53+
if isinstance(root, nodes.FilterNode):
54+
return nodes.FilterNode(
55+
root.child, _rewrite_expressions(root.predicate, root.schema).expr
56+
)
57+
58+
if isinstance(root, nodes.OrderByNode):
59+
by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by)
60+
return nodes.OrderByNode(root.child, by)
4261

43-
# TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too.
4462
return root
4563

4664

65+
def _rewrite_ordering_expr(
66+
expr: nodes.OrderingExpression, schema: schema.ArraySchema
67+
) -> nodes.OrderingExpression:
68+
by = _rewrite_expressions(expr.scalar_expression, schema).expr
69+
return nodes.OrderingExpression(by, expr.direction, expr.na_last)
70+
71+
4772
@functools.cache
4873
def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr:
4974
if isinstance(expr, ex.DerefOp):
@@ -78,37 +103,23 @@ def _rewrite_op_expr(
78103
if isinstance(expr.op, ops.AddOp):
79104
return _rewrite_add_op(inputs[0], inputs[1])
80105

81-
input_types = tuple(map(lambda x: x.dtype, inputs))
82-
return _TypedExpr(expr, expr.op.output_type(*input_types))
106+
return _TypedExpr.create_op_expr(expr.op, *inputs)
83107

84108

85109
def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
86-
result_op: ops.BinaryOp = ops.sub_op
87110
if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
88-
result_op = ops.timestamp_diff_op
111+
return _TypedExpr.create_op_expr(ops.timestamp_diff_op, left, right)
89112

90-
return _TypedExpr(
91-
result_op.as_expr(left.expr, right.expr),
92-
result_op.output_type(left.dtype, right.dtype),
93-
)
113+
return _TypedExpr.create_op_expr(ops.sub_op, left, right)
94114

95115

96116
def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
97117
if dtypes.is_datetime_like(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
98-
return _TypedExpr(
99-
ops.timestamp_add_op.as_expr(left.expr, right.expr),
100-
ops.timestamp_add_op.output_type(left.dtype, right.dtype),
101-
)
118+
return _TypedExpr.create_op_expr(ops.timestamp_add_op, left, right)
102119

103120
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right.dtype):
104121
# Re-arrange operands such that timestamp is always on the left and timedelta is
105122
# always on the right.
106-
return _TypedExpr(
107-
ops.timestamp_add_op.as_expr(right.expr, left.expr),
108-
ops.timestamp_add_op.output_type(right.dtype, left.dtype),
109-
)
123+
return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)
110124

111-
return _TypedExpr(
112-
ops.add_op.as_expr(left.expr, right.expr),
113-
ops.add_op.output_type(left.dtype, right.dtype),
114-
)
125+
return _TypedExpr.create_op_expr(ops.add_op, left, right)

bigframes/dtypes.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ def is_comparable(type_: ExpressionType) -> bool:
358358

359359
def is_orderable(type_: ExpressionType) -> bool:
360360
# On BQ side, ARRAY, STRUCT, GEOGRAPHY, JSON are not orderable
361-
return type_ in _ORDERABLE_SIMPLE_TYPES
361+
return type_ in _ORDERABLE_SIMPLE_TYPES or type_ is TIMEDELTA_DTYPE
362362

363363

364364
_CLUSTERABLE_SIMPLE_TYPES = set(

tests/system/small/operations/test_timedeltas.py

+123 -9
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import datetime
17+
import operator
1718

1819
import numpy as np
1920
import pandas as pd
@@ -28,12 +29,23 @@ def temporal_dfs(session):
2829
"datetime_col": [
2930
pd.Timestamp("2025-02-01 01:00:01"),
3031
pd.Timestamp("2019-01-02 02:00:00"),
32+
pd.Timestamp("1997-01-01 19:00:00"),
3133
],
3234
"timestamp_col": [
3335
pd.Timestamp("2023-01-01 01:00:01", tz="UTC"),
3436
pd.Timestamp("2024-01-02 02:00:00", tz="UTC"),
37+
pd.Timestamp("2005-03-05 02:00:00", tz="UTC"),
38+
],
39+
"timedelta_col_1": [
40+
pd.Timedelta(3, "s"),
41+
pd.Timedelta(-4, "d"),
42+
pd.Timedelta(5, "h"),
43+
],
44+
"timedelta_col_2": [
45+
pd.Timedelta(2, "s"),
46+
pd.Timedelta(-4, "d"),
47+
pd.Timedelta(6, "h"),
3548
],
36-
"timedelta_col": [pd.Timedelta(3, "s"), pd.Timedelta(-4, "d")],
3749
}
3850
)
3951

@@ -53,10 +65,10 @@ def test_timestamp_add__ts_series_plus_td_series(temporal_dfs, column, pd_dtype)
5365
bf_df, pd_df = temporal_dfs
5466

5567
actual_result = (
56-
(bf_df[column] + bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
68+
(bf_df[column] + bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
5769
)
5870

59-
expected_result = pd_df[column] + pd_df["timedelta_col"]
71+
expected_result = pd_df[column] + pd_df["timedelta_col_1"]
6072
pandas.testing.assert_series_equal(
6173
actual_result, expected_result, check_index_type=False
6274
)
@@ -94,10 +106,10 @@ def test_timestamp_add__td_series_plus_ts_series(temporal_dfs, column, pd_dtype)
94106
bf_df, pd_df = temporal_dfs
95107

96108
actual_result = (
97-
(bf_df["timedelta_col"] + bf_df[column]).to_pandas().astype(pd_dtype)
109+
(bf_df["timedelta_col_1"] + bf_df[column]).to_pandas().astype(pd_dtype)
98110
)
99111

100-
expected_result = pd_df["timedelta_col"] + pd_df[column]
112+
expected_result = pd_df["timedelta_col_1"] + pd_df[column]
101113
pandas.testing.assert_series_equal(
102114
actual_result, expected_result, check_index_type=False
103115
)
@@ -120,10 +132,10 @@ def test_timestamp_add__ts_literal_plus_td_series(temporal_dfs):
120132
timestamp = pd.Timestamp("2025-01-01", tz="UTC")
121133

122134
actual_result = (
123-
(timestamp + bf_df["timedelta_col"]).to_pandas().astype("datetime64[ns, UTC]")
135+
(timestamp + bf_df["timedelta_col_1"]).to_pandas().astype("datetime64[ns, UTC]")
124136
)
125137

126-
expected_result = timestamp + pd_df["timedelta_col"]
138+
expected_result = timestamp + pd_df["timedelta_col_1"]
127139
pandas.testing.assert_series_equal(
128140
actual_result, expected_result, check_index_type=False
129141
)
@@ -140,10 +152,10 @@ def test_timestamp_add_with_numpy_op(temporal_dfs, column, pd_dtype):
140152
bf_df, pd_df = temporal_dfs
141153

142154
actual_result = (
143-
np.add(bf_df[column], bf_df["timedelta_col"]).to_pandas().astype(pd_dtype)
155+
np.add(bf_df[column], bf_df["timedelta_col_1"]).to_pandas().astype(pd_dtype)
144156
)
145157

146-
expected_result = np.add(pd_df[column], pd_df["timedelta_col"])
158+
expected_result = np.add(pd_df[column], pd_df["timedelta_col_1"])
147159
pandas.testing.assert_series_equal(
148160
actual_result, expected_result, check_index_type=False
149161
)
@@ -164,3 +176,105 @@ def test_timestamp_add_dataframes(temporal_dfs):
164176
pandas.testing.assert_frame_equal(
165177
actual_result, expected_result, check_index_type=False
166178
)
179+
180+
181+
@pytest.mark.parametrize(
182+
"compare_func",
183+
[
184+
pytest.param(operator.gt, id="gt"),
185+
pytest.param(operator.ge, id="ge"),
186+
pytest.param(operator.eq, id="eq"),
187+
pytest.param(operator.ne, id="ne"),
188+
pytest.param(operator.lt, id="lt"),
189+
pytest.param(operator.le, id="le"),
190+
],
191+
)
192+
def test_timedelta_series_comparison(temporal_dfs, compare_func):
193+
bf_df, pd_df = temporal_dfs
194+
195+
actual_result = compare_func(
196+
bf_df["timedelta_col_1"], bf_df["timedelta_col_2"]
197+
).to_pandas()
198+
199+
expected_result = compare_func(
200+
pd_df["timedelta_col_1"], pd_df["timedelta_col_2"]
201+
).astype("boolean")
202+
pandas.testing.assert_series_equal(
203+
actual_result, expected_result, check_index_type=False
204+
)
205+
206+
207+
@pytest.mark.parametrize(
208+
"compare_func",
209+
[
210+
pytest.param(operator.gt, id="gt"),
211+
pytest.param(operator.ge, id="ge"),
212+
pytest.param(operator.eq, id="eq"),
213+
pytest.param(operator.ne, id="ne"),
214+
pytest.param(operator.lt, id="lt"),
215+
pytest.param(operator.le, id="le"),
216+
],
217+
)
218+
def test_timedelta_series_and_literal_comparison(temporal_dfs, compare_func):
219+
bf_df, pd_df = temporal_dfs
220+
literal = pd.Timedelta(3, "s")
221+
222+
actual_result = compare_func(literal, bf_df["timedelta_col_2"]).to_pandas()
223+
224+
expected_result = compare_func(literal, pd_df["timedelta_col_2"]).astype("boolean")
225+
pandas.testing.assert_series_equal(
226+
actual_result, expected_result, check_index_type=False
227+
)
228+
229+
230+
def test_timedelta_filtering(session):
231+
pd_series = pd.Series(
232+
[
233+
pd.Timestamp("2025-01-01 01:00:00"),
234+
pd.Timestamp("2025-01-01 02:00:00"),
235+
pd.Timestamp("2025-01-01 03:00:00"),
236+
]
237+
)
238+
bf_series = session.read_pandas(pd_series)
239+
timestamp = pd.Timestamp("2025-01-01, 00:00:01")
240+
241+
actual_result = (
242+
bf_series[((bf_series - timestamp) > pd.Timedelta(1, "h"))]
243+
.to_pandas()
244+
.astype("<M8[ns]")
245+
)
246+
247+
expected_result = pd_series[(pd_series - timestamp) > pd.Timedelta(1, "h")]
248+
pandas.testing.assert_series_equal(
249+
actual_result, expected_result, check_index_type=False
250+
)
251+
252+
253+
def test_timedelta_ordering(session):
254+
pd_df = pd.DataFrame(
255+
{
256+
"col_1": [
257+
pd.Timestamp("2025-01-01 01:00:00"),
258+
pd.Timestamp("2025-01-01 02:00:00"),
259+
pd.Timestamp("2025-01-01 03:00:00"),
260+
],
261+
"col_2": [
262+
pd.Timestamp("2025-01-01 01:00:02"),
263+
pd.Timestamp("2025-01-01 02:00:01"),
264+
pd.Timestamp("2025-01-01 02:59:59"),
265+
],
266+
}
267+
)
268+
bf_df = session.read_pandas(pd_df)
269+
270+
actual_result = (
271+
(bf_df["col_2"] - bf_df["col_1"])
272+
.sort_values()
273+
.to_pandas()
274+
.astype("timedelta64[ns]")
275+
)
276+
277+
expected_result = (pd_df["col_2"] - pd_df["col_1"]).sort_values()
278+
pandas.testing.assert_series_equal(
279+
actual_result, expected_result, check_index_type=False
280+
)

0 commit comments

Comments
 (0)