Skip to content

Commit e8e66cf

Browse files
feat: Add support for numpy expm1, log1p, floor, ceil, arctan2 ops (#505)
Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 036649e commit e8e66cf

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

bigframes/core/compile/scalar_op_compiler.py

+54
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,13 @@ def arctan_op_impl(x: ibis_types.Value):
257257
return typing.cast(ibis_types.NumericValue, x).atan()
258258

259259

260+
@scalar_op_compiler.register_binary_op(ops.arctan2_op)
261+
def arctan2_op_impl(x: ibis_types.Value, y: ibis_types.Value):
262+
return typing.cast(ibis_types.NumericValue, x).atan2(
263+
typing.cast(ibis_types.NumericValue, y)
264+
)
265+
266+
260267
# Hyperbolic trig functions
261268
# BQ has these functions, but Ibis doesn't
262269
@scalar_op_compiler.register_unary_op(ops.sinh_op)
@@ -319,6 +326,30 @@ def arctanh_op_impl(x: ibis_types.Value):
319326

320327

321328
# Numeric Ops
329+
@scalar_op_compiler.register_unary_op(ops.floor_op)
330+
def floor_op_impl(x: ibis_types.Value):
331+
x_numeric = typing.cast(ibis_types.NumericValue, x)
332+
if x_numeric.type().is_integer():
333+
return x_numeric.cast(ibis_dtypes.Float64())
334+
if x_numeric.type().is_floating():
335+
# Default ibis impl tries to cast to integer, which doesn't match pandas and can overflow
336+
return float_floor(x_numeric)
337+
else: # numeric
338+
return x_numeric.floor()
339+
340+
341+
@scalar_op_compiler.register_unary_op(ops.ceil_op)
342+
def ceil_op_impl(x: ibis_types.Value):
343+
x_numeric = typing.cast(ibis_types.NumericValue, x)
344+
if x_numeric.type().is_integer():
345+
return x_numeric.cast(ibis_dtypes.Float64())
346+
if x_numeric.type().is_floating():
347+
# Default ibis impl tries to cast to integer, which doesn't match pandas and can overflow
348+
return float_ceil(x_numeric)
349+
else: # numeric
350+
return x_numeric.ceil()
351+
352+
322353
@scalar_op_compiler.register_unary_op(ops.abs_op)
323354
def abs_op_impl(x: ibis_types.Value):
324355
return typing.cast(ibis_types.NumericValue, x).abs()
@@ -347,13 +378,23 @@ def ln_op_impl(x: ibis_types.Value):
347378
return (~domain).ifelse(out_of_domain, numeric_value.ln())
348379

349380

381+
@scalar_op_compiler.register_unary_op(ops.log1p_op)
382+
def log1p_op_impl(x: ibis_types.Value):
383+
return ln_op_impl(_ibis_num(1) + x)
384+
385+
350386
@scalar_op_compiler.register_unary_op(ops.exp_op)
351387
def exp_op_impl(x: ibis_types.Value):
352388
numeric_value = typing.cast(ibis_types.NumericValue, x)
353389
domain = numeric_value < _FLOAT64_EXP_BOUND
354390
return (~domain).ifelse(_INF, numeric_value.exp())
355391

356392

393+
@scalar_op_compiler.register_unary_op(ops.expm1_op)
394+
def expm1_op_impl(x: ibis_types.Value):
395+
return exp_op_impl(x) - _ibis_num(1)
396+
397+
357398
@scalar_op_compiler.register_unary_op(ops.invert_op)
358399
def invert_op_impl(x: ibis_types.Value):
359400
return typing.cast(ibis_types.NumericValue, x).negate()
@@ -1318,3 +1359,16 @@ def _ibis_num(number: float):
13181359
@ibis.udf.scalar.builtin
13191360
def timestamp(a: str) -> ibis_dtypes.timestamp:
13201361
"""Convert string to timestamp."""
1362+
1363+
1364+
# Need these because ibis otherwise tries to do casts to int that can fail
1365+
@ibis.udf.scalar.builtin(name="floor")
1366+
def float_floor(a: float) -> float:
1367+
"""Floor of a float via the backend's builtin ``floor``; returns a float (no integer cast)."""
1368+
return 0 # pragma: NO COVER
1369+
1370+
1371+
@ibis.udf.scalar.builtin(name="ceil")
1372+
def float_ceil(a: float) -> float:
1373+
"""Ceiling of a float via the backend's builtin ``ceil``; returns a float (no integer cast)."""
1374+
return 0 # pragma: NO COVER

bigframes/operations/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,16 @@ def create_ternary_op(
246246
arcsinh_op = create_unary_op(name="arcsinh", type_rule=op_typing.REAL_NUMERIC)
247247
arccosh_op = create_unary_op(name="arccosh", type_rule=op_typing.REAL_NUMERIC)
248248
arctanh_op = create_unary_op(name="arctanh", type_rule=op_typing.REAL_NUMERIC)
249+
arctan2_op = create_binary_op(name="arctan2", type_rule=op_typing.REAL_NUMERIC)
249250
## Numeric Ops
251+
floor_op = create_unary_op(name="floor", type_rule=op_typing.REAL_NUMERIC)
252+
ceil_op = create_unary_op(name="ceil", type_rule=op_typing.REAL_NUMERIC)
250253
abs_op = create_unary_op(name="abs", type_rule=op_typing.INPUT_TYPE)
251254
exp_op = create_unary_op(name="exp", type_rule=op_typing.REAL_NUMERIC)
255+
expm1_op = create_unary_op(name="expm1", type_rule=op_typing.REAL_NUMERIC)
252256
ln_op = create_unary_op(name="log", type_rule=op_typing.REAL_NUMERIC)
253257
log10_op = create_unary_op(name="log10", type_rule=op_typing.REAL_NUMERIC)
258+
log1p_op = create_unary_op(name="log1p", type_rule=op_typing.REAL_NUMERIC)
254259
sqrt_op = create_unary_op(name="sqrt", type_rule=op_typing.REAL_NUMERIC)
255260

256261

@@ -540,6 +545,10 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
540545
np.log10: log10_op,
541546
np.sqrt: sqrt_op,
542547
np.abs: abs_op,
548+
np.floor: floor_op,
549+
np.ceil: ceil_op,
550+
np.log1p: log1p_op,
551+
np.expm1: expm1_op,
543552
}
544553

545554

@@ -549,4 +558,5 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
549558
np.multiply: mul_op,
550559
np.divide: div_op,
551560
np.power: pow_op,
561+
np.arctan2: arctan2_op,
552562
}

tests/system/small/test_numpy.py

+22
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def test_series_ufuncs(floats_pd, floats_bf, opname):
5656
("log10",),
5757
("sqrt",),
5858
("abs",),
59+
("floor",),
60+
("ceil",),
61+
("expm1",),
62+
("log1p",),
5963
],
6064
)
6165
def test_df_ufuncs(scalars_dfs, opname):
@@ -77,6 +81,7 @@ def test_df_ufuncs(scalars_dfs, opname):
7781
("multiply",),
7882
("divide",),
7983
("power",),
84+
("arctan2",),
8085
],
8186
)
8287
def test_series_binary_ufuncs(floats_product_pd, floats_product_bf, opname):
@@ -112,6 +117,23 @@ def test_df_binary_ufuncs(scalars_dfs, opname):
112117
pd.testing.assert_frame_equal(bf_result, pd_result)
113118

114119

120+
@pytest.mark.parametrize(
121+
("x", "y"),
122+
[
123+
("int64_col", "int64_col"),
124+
("float64_col", "int64_col"),
125+
],
126+
)
127+
def test_series_atan2(scalars_dfs, x, y):
128+
# Test atan2 separately as pandas errors when passing entire df as input, so pass only series
129+
scalars_df, scalars_pandas_df = scalars_dfs
130+
131+
bf_result = np.arctan2(scalars_df[x], scalars_df[y]).to_pandas()
132+
pd_result = np.arctan2(scalars_pandas_df[x], scalars_pandas_df[y])
133+
134+
pd.testing.assert_series_equal(bf_result, pd_result)
135+
136+
115137
def test_series_binary_ufuncs_reverse(scalars_dfs):
116138
scalars_df, scalars_pandas_df = scalars_dfs
117139

0 commit comments

Comments
 (0)