@@ -125,6 +125,9 @@ def _rewrite_op_expr(
125
125
# but for timedeltas: int(timedelta) // float => int(timedelta)
126
126
return _rewrite_floordiv_op (inputs [0 ], inputs [1 ])
127
127
128
+ if isinstance (expr .op , ops .ToTimedeltaOp ):
129
+ return _rewrite_to_timedelta_op (expr .op , inputs [0 ])
130
+
128
131
return _TypedExpr .create_op_expr (expr .op , * inputs )
129
132
130
133
@@ -154,9 +157,9 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
154
157
result = _TypedExpr .create_op_expr (ops .mul_op , left , right )
155
158
156
159
if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
157
- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
160
+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
158
161
if dtypes .is_numeric (left .dtype ) and right .dtype is dtypes .TIMEDELTA_DTYPE :
159
- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
162
+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
160
163
161
164
return result
162
165
@@ -165,7 +168,7 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
165
168
result = _TypedExpr .create_op_expr (ops .div_op , left , right )
166
169
167
170
if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
168
- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
171
+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
169
172
170
173
return result
171
174
@@ -174,11 +177,19 @@ def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
174
177
result = _TypedExpr .create_op_expr (ops .floordiv_op , left , right )
175
178
176
179
if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
177
- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
180
+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
178
181
179
182
return result
180
183
181
184
185
+ def _rewrite_to_timedelta_op (op : ops .ToTimedeltaOp , arg : _TypedExpr ):
186
+ if arg .dtype is dtypes .TIMEDELTA_DTYPE :
187
+ # Do nothing for values that are already timedeltas
188
+ return arg
189
+
190
+ return _TypedExpr .create_op_expr (op , arg )
191
+
192
+
182
193
@functools .cache
183
194
def _rewrite_aggregation (
184
195
aggregation : ex .Aggregation , schema : schema .ArraySchema
0 commit comments