Skip to content

Commit 650a190

Browse files
authored
fix: calling to_timedelta() over timedeltas no longer changes their values (#1411)
* fix: fix a bug where to_timedelta() calls over timedeltas change their values * add tests for floats too
1 parent 2993b28 commit 650a190

File tree

5 files changed

+60
-8
lines changed

5 files changed

+60
-8
lines changed

bigframes/core/compile/scalar_op_compiler.py

+5
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,11 @@ def to_timedelta_op_impl(x: ibis_types.Value, op: ops.ToTimedeltaOp):
11861186
).floor()
11871187

11881188

1189+
@scalar_op_compiler.register_unary_op(ops.timedelta_floor_op)
1190+
def timedelta_floor_op_impl(x: ibis_types.NumericValue):
1191+
return x.floor()
1192+
1193+
11891194
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
11901195
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
11911196
ibis_node = getattr(op.func, "ibis_node", None)

bigframes/core/rewrite/timedeltas.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ def _rewrite_op_expr(
125125
# but for timedeltas: int(timedelta) // float => int(timedelta)
126126
return _rewrite_floordiv_op(inputs[0], inputs[1])
127127

128+
if isinstance(expr.op, ops.ToTimedeltaOp):
129+
return _rewrite_to_timedelta_op(expr.op, inputs[0])
130+
128131
return _TypedExpr.create_op_expr(expr.op, *inputs)
129132

130133

@@ -154,9 +157,9 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
154157
result = _TypedExpr.create_op_expr(ops.mul_op, left, right)
155158

156159
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
157-
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
160+
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
158161
if dtypes.is_numeric(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
159-
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
162+
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
160163

161164
return result
162165

@@ -165,7 +168,7 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
165168
result = _TypedExpr.create_op_expr(ops.div_op, left, right)
166169

167170
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
168-
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
171+
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
169172

170173
return result
171174

@@ -174,11 +177,19 @@ def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
174177
result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right)
175178

176179
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
177-
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
180+
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
178181

179182
return result
180183

181184

185+
def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr):
186+
if arg.dtype is dtypes.TIMEDELTA_DTYPE:
187+
# Do nothing for values that are already timedeltas
188+
return arg
189+
190+
return _TypedExpr.create_op_expr(op, arg)
191+
192+
182193
@functools.cache
183194
def _rewrite_aggregation(
184195
aggregation: ex.Aggregation, schema: schema.ArraySchema

bigframes/operations/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@
184184
from bigframes.operations.struct_ops import StructFieldOp, StructOp
185185
from bigframes.operations.time_ops import hour_op, minute_op, normalize_op, second_op
186186
from bigframes.operations.timedelta_ops import (
187+
timedelta_floor_op,
187188
timestamp_add_op,
188189
timestamp_sub_op,
189190
ToTimedeltaOp,
@@ -259,6 +260,7 @@
259260
"second_op",
260261
"normalize_op",
261262
# Timedelta ops
263+
"timedelta_floor_op",
262264
"timestamp_add_op",
263265
"timestamp_sub_op",
264266
"ToTimedeltaOp",

bigframes/operations/timedelta_ops.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,26 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
3636

3737

3838
@dataclasses.dataclass(frozen=True)
39-
class TimestampAdd(base_ops.BinaryOp):
39+
class TimedeltaFloorOp(base_ops.UnaryOp):
40+
"""Floors the numeric value to the nearest integer and use it to represent a timedelta.
41+
42+
This operator is only meant to be used during expression tree rewrites. Do not use it anywhere else!
43+
"""
44+
45+
name: typing.ClassVar[str] = "timedelta_floor"
46+
47+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
48+
input_type = input_types[0]
49+
if dtypes.is_numeric(input_type) or input_type is dtypes.TIMEDELTA_DTYPE:
50+
return dtypes.TIMEDELTA_DTYPE
51+
raise TypeError(f"unsupported type: {input_type}")
52+
53+
54+
timedelta_floor_op = TimedeltaFloorOp()
55+
56+
57+
@dataclasses.dataclass(frozen=True)
58+
class TimestampAddOp(base_ops.BinaryOp):
4059
name: typing.ClassVar[str] = "timestamp_add"
4160

4261
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
@@ -57,10 +76,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
5776
)
5877

5978

60-
timestamp_add_op = TimestampAdd()
79+
timestamp_add_op = TimestampAddOp()
6180

6281

63-
class TimestampSub(base_ops.BinaryOp):
82+
class TimestampSubOp(base_ops.BinaryOp):
6483
name: typing.ClassVar[str] = "timestamp_sub"
6584

6685
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
@@ -76,4 +95,4 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
7695
)
7796

7897

79-
timestamp_sub_op = TimestampSub()
98+
timestamp_sub_op = TimestampSubOp()

tests/system/small/test_pandas.py

+15
Original file line numberDiff line numberDiff line change
@@ -829,3 +829,18 @@ def test_to_timedelta_with_bf_series_invalid_unit(session, unit):
829829
@pytest.mark.parametrize("input", [1, 1.2, "1s"])
830830
def test_to_timedelta_non_bf_series(input):
831831
assert bpd.to_timedelta(input) == pd.to_timedelta(input)
832+
833+
834+
def test_to_timedelta_on_timedelta_series__should_be_no_op(scalars_dfs):
835+
bf_df, pd_df = scalars_dfs
836+
bf_series = bpd.to_timedelta(bf_df["int64_too"], unit="us")
837+
pd_series = pd.to_timedelta(pd_df["int64_too"], unit="us")
838+
839+
actual_result = (
840+
bpd.to_timedelta(bf_series, unit="s").to_pandas().astype("timedelta64[ns]")
841+
)
842+
843+
expected_result = pd.to_timedelta(pd_series, unit="s")
844+
pd.testing.assert_series_equal(
845+
actual_result, expected_result, check_index_type=False
846+
)

0 commit comments

Comments
 (0)