
feat: df.apply(axis=1) to support remote function with multiple params #851

Merged: 16 commits, Aug 2, 2024
26 changes: 24 additions & 2 deletions bigframes/core/compile/scalar_op_compiler.py
@@ -191,19 +191,27 @@ def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp):

return decorator

def register_nary_op(self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]]):
def register_nary_op(
self, op_ref: typing.Union[ops.NaryOp, type[ops.NaryOp]], pass_op: bool = False
):
"""
Decorator to register a nary op implementation.

Args:
op_ref (NaryOp or NaryOp type):
Class or instance of operator that is implemented by the decorated function.
pass_op (bool):
Set to true if the implementation takes the operator object as the last argument.
This is needed for parameterized ops where parameters are part of the op object.
"""
key = typing.cast(str, op_ref.name)

def decorator(impl: typing.Callable[..., ibis_types.Value]):
def normalized_impl(args: typing.Sequence[ibis_types.Value], op: ops.RowOp):
return impl(*args)
if pass_op:
return impl(*args, op=op)
else:
return impl(*args)

self._register(key, normalized_impl)
return impl
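
For context, a minimal sketch of how a parameterized op might opt in to `pass_op` (the op class and field names below are illustrative, not part of this PR):

```python
# Hypothetical parameterized nary op: with pass_op=True the compiler forwards
# the op instance as a keyword argument, so the implementation can read
# parameters (e.g. a user-supplied callable) off the op object rather than
# from the column operands.
@scalar_op_compiler.register_nary_op(ops.SomeParameterizedOp, pass_op=True)
def some_parameterized_op_impl(
    *operands: ibis_types.Value, op: ops.SomeParameterizedOp
):
    return op.some_callable_field(*operands)
```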
@@ -1444,6 +1452,7 @@ def clip_op(
)


# N-ary Operations
@scalar_op_compiler.register_nary_op(ops.case_when_op)
def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
# ibis can handle most type coercions, but we need to force bool -> int
@@ -1463,6 +1472,19 @@ def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
return case_val.end()


@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
def nary_remote_function_op_impl(
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
):
ibis_node = getattr(op.func, "ibis_node", None)
if ibis_node is None:
raise TypeError(
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
)
result = ibis_node(*operands)
return result

Review comment (Contributor), on the `ibis_node` lookup above: Don't need to do this in the current PR, but we need to move away from storing ibis values in the op definition. We will want to generate this at compile time only, to allow non-ibis compilation.

Reply (shobsi, Contributor Author, Jul 31, 2024): Ack. b/356686746


# Helpers
def is_null(value) -> bool:
# float NaN/inf should be treated as distinct from 'true' null values
148 changes: 91 additions & 57 deletions bigframes/dataframe.py
@@ -3433,9 +3433,9 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
raise ValueError(f"na_action={na_action} not supported")

# TODO(shobs): Support **kwargs
# Reproject as workaround to applying filter too late. This forces the filter
# to be applied before passing data to remote function, protecting from bad
# inputs causing errors.
# Reproject as workaround to applying filter too late. This forces the
# filter to be applied before passing data to remote function,
# protecting from bad inputs causing errors.
reprojected_df = DataFrame(self._block._force_reproject())
return reprojected_df._apply_unary_op(
ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None))
@@ -3448,65 +3448,99 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
category=bigframes.exceptions.PreviewWarning,
)

# Early check whether the dataframe dtypes are currently supported
# in the remote function
# NOTE: Keep in sync with the value converters used in the gcf code
# generated in remote_function_template.py
remote_function_supported_dtypes = (
bigframes.dtypes.INT_DTYPE,
bigframes.dtypes.FLOAT_DTYPE,
bigframes.dtypes.BOOL_DTYPE,
bigframes.dtypes.BYTES_DTYPE,
bigframes.dtypes.STRING_DTYPE,
)
supported_dtypes_types = tuple(
type(dtype)
for dtype in remote_function_supported_dtypes
if not isinstance(dtype, pandas.ArrowDtype)
)
# Check ArrowDtype separately since multiple BigQuery types map to
# ArrowDtype, including BYTES and TIMESTAMP.
supported_arrow_types = tuple(
dtype.pyarrow_dtype
for dtype in remote_function_supported_dtypes
if isinstance(dtype, pandas.ArrowDtype)
)
supported_dtypes_hints = tuple(
str(dtype) for dtype in remote_function_supported_dtypes
)

for dtype in self.dtypes:
if (
# Not one of the pandas/numpy types.
not isinstance(dtype, supported_dtypes_types)
# And not one of the arrow types.
and not (
isinstance(dtype, pandas.ArrowDtype)
and any(
dtype.pyarrow_dtype.equals(arrow_type)
for arrow_type in supported_arrow_types
)
)
):
raise NotImplementedError(
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
f" Supported dtypes are {supported_dtypes_hints}."
)

# Check if the function is a remote function
if not hasattr(func, "bigframes_remote_function"):
raise ValueError("For axis=1 a remote function must be used.")

# Serialize the rows as json values
block = self._get_block()
rows_as_json_series = bigframes.series.Series(
block._get_rows_as_json_values()
)
is_row_processor = getattr(func, "is_row_processor")
if is_row_processor:
# Early check whether the dataframe dtypes are currently supported
# in the remote function
# NOTE: Keep in sync with the value converters used in the gcf code
# generated in remote_function_template.py
remote_function_supported_dtypes = (
bigframes.dtypes.INT_DTYPE,
bigframes.dtypes.FLOAT_DTYPE,
bigframes.dtypes.BOOL_DTYPE,
bigframes.dtypes.BYTES_DTYPE,
bigframes.dtypes.STRING_DTYPE,
)
supported_dtypes_types = tuple(
type(dtype)
for dtype in remote_function_supported_dtypes
if not isinstance(dtype, pandas.ArrowDtype)
)
# Check ArrowDtype separately since multiple BigQuery types map to
# ArrowDtype, including BYTES and TIMESTAMP.
supported_arrow_types = tuple(
dtype.pyarrow_dtype
for dtype in remote_function_supported_dtypes
if isinstance(dtype, pandas.ArrowDtype)
)
supported_dtypes_hints = tuple(
str(dtype) for dtype in remote_function_supported_dtypes
)

# Apply the function
result_series = rows_as_json_series._apply_unary_op(
ops.RemoteFunctionOp(func=func, apply_on_null=True)
)
for dtype in self.dtypes:
if (
# Not one of the pandas/numpy types.
not isinstance(dtype, supported_dtypes_types)
# And not one of the arrow types.
and not (
isinstance(dtype, pandas.ArrowDtype)
and any(
dtype.pyarrow_dtype.equals(arrow_type)
for arrow_type in supported_arrow_types
)
)
):
raise NotImplementedError(
f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
f" Supported dtypes are {supported_dtypes_hints}."
)

# Serialize the rows as json values
block = self._get_block()
rows_as_json_series = bigframes.series.Series(
block._get_rows_as_json_values()
)

# Apply the function
result_series = rows_as_json_series._apply_unary_op(
ops.RemoteFunctionOp(func=func, apply_on_null=True)
)
else:
# This is a special case where we are providing a not-pandas-like
# extension. If the remote function can take one or more params
# then we assume that here the user intention is to use the
# column values of the dataframe as arguments to the function.
# For this to work the following conditions must be true:
# 1. The number of input params in the function must be the same
# as the number of columns in the dataframe
# 2. The dtypes of the columns in the dataframe must be
# compatible with the data types of the input params
# 3. The order of the columns in the dataframe must correspond
# to the order of the input params in the function

Review comment (Contributor): Do we want to accept compatible dtypes? E.g. the column is int, but the function takes decimal?

Reply (shobsi, Contributor Author, Jul 31, 2024): A function taking decimal is not something we support right now. There is a longer-term desire to expand the datatype support.

RF_SUPPORTED_IO_PYTHON_TYPES = {bool, bytes, float, int, str}

Reply (shobsi, Contributor Author): Having said that, are there other places where we reconcile dtypes across dataframes or in operations?
udf_input_dtypes = getattr(func, "input_dtypes")
if len(udf_input_dtypes) != len(self.columns):
raise ValueError(
f"Remote function takes {len(udf_input_dtypes)} arguments but DataFrame has {len(self.columns)} columns."
)
if udf_input_dtypes != tuple(self.dtypes.to_list()):
raise ValueError(
f"Remote function takes arguments of types {udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
)

series_list = [self[col] for col in self.columns]
# Reproject as workaround to applying filter too late. This forces the
# filter to be applied before passing data to remote function,
# protecting from bad inputs causing errors.
reprojected_series = bigframes.series.Series(
series_list[0]._block._force_reproject()
)
Review comment (Contributor) on lines +3538 to +3540: Why do we need a force_reproject?

Reply (shobsi, Contributor Author): I am just copy-pasting the pattern introduced in this change: ff3bb89#diff-8718ceb6a8f6b68d7b06a15e84043fb866c500d5bfb1f33ad8c945f06815a140

Is the reasoning (which got a bit detached unintentionally, sitting at the beginning of the function) still valid?

# Reproject as workaround to applying filter too late. This forces the filter
# to be applied before passing data to remote function, protecting from bad
# inputs causing errors.

Reply (shobsi, Contributor Author): Actually it does make a difference; quickly tested in #874, and the series.mask doctest is failing.

Reply (shobsi, Contributor Author): I moved the comment back close to reproject, PTAL.

result_series = reprojected_series._apply_nary_op(
ops.NaryRemoteFunctionOp(func=func), series_list[1:]
)
result_series.name = None

# Return Series with materialized result so that any error in the remote
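
To illustrate the feature this PR enables, a hedged usage sketch (assumes a configured BigQuery DataFrames session with permission to deploy remote functions; the function and column names are illustrative):

```python
import bigframes.pandas as bpd

# One function parameter per DataFrame column: columns are matched to
# parameters by position, and the column dtypes must match the declared
# input types.
@bpd.remote_function([int, float], float, reuse=False)
def add_weighted(x: int, y: float) -> float:
    return x + 2.5 * y

df = bpd.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})

# Previously axis=1 required a row-processor function taking a single
# serialized row; with this PR a multi-param remote function works too.
result = df.apply(add_weighted, axis=1)
```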
4 changes: 4 additions & 0 deletions bigframes/exceptions.py
@@ -57,3 +57,7 @@ class QueryComplexityError(RuntimeError):

class TimeTravelDisabledWarning(Warning):
"""A query was reattempted without time travel."""


class UnknownDataTypeWarning(Warning):
"""Data type is unknown."""
40 changes: 37 additions & 3 deletions bigframes/functions/remote_function.py
@@ -66,6 +66,7 @@
from bigframes import clients
import bigframes.constants as constants
import bigframes.core.compile.ibis_types
import bigframes.dtypes
import bigframes.functions.remote_function_template

logger = logging.getLogger(__name__)
@@ -895,8 +896,8 @@ def remote_function(
reuse (bool, Optional):
Reuse the remote function if already exists.
`True` by default, which will result in reusing an existing remote
function and corresponding cloud function (if any) that was
previously created for the same udf.
function and corresponding cloud function that was previously
created (if any) for the same udf.
Please note that for an unnamed (i.e. created without an explicit
`name` argument) remote function, the BigQuery DataFrames
session id is attached in the cloud artifacts names. So for the
@@ -1174,7 +1175,9 @@ def try_delattr(attr):

try_delattr("bigframes_cloud_function")
try_delattr("bigframes_remote_function")
try_delattr("input_dtypes")
try_delattr("output_dtype")
try_delattr("is_row_processor")
try_delattr("ibis_node")

(
@@ -1216,12 +1219,20 @@
rf_name
)
)

func.input_dtypes = tuple(
[
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
input_type
)
for input_type in ibis_signature.input_types
]
)
func.output_dtype = (
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
ibis_signature.output_type
)
)
func.is_row_processor = is_row_processor
func.ibis_node = node

# If a new remote function was created, update the cloud artifacts
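
As a hedged illustration of the attributes that now end up on the decorated function (continuing the hypothetical `add_weighted` sketch above; the dtype reprs are indicative):

```python
# Attributes attached by the remote_function machinery:
add_weighted.bigframes_remote_function  # fully qualified BigQuery routine id
add_weighted.bigframes_cloud_function   # backing Cloud Function name
add_weighted.input_dtypes               # e.g. (Int64Dtype(), Float64Dtype())
add_weighted.output_dtype               # e.g. Float64Dtype()
add_weighted.is_row_processor           # False for a multi-param scalar udf
```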
@@ -1305,6 +1316,29 @@ def func(*ignored_args, **ignored_kwargs):
signature=(ibis_signature.input_types, ibis_signature.output_type),
)
func.bigframes_remote_function = str(routine_ref) # type: ignore

# set input bigframes data types
has_unknown_dtypes = False
function_input_dtypes = []
for ibis_type in ibis_signature.input_types:
input_dtype = cast(bigframes.dtypes.Dtype, bigframes.dtypes.DEFAULT_DTYPE)
if ibis_type is None:
has_unknown_dtypes = True
else:
input_dtype = (
bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype(
ibis_type
)
)
function_input_dtypes.append(input_dtype)
if has_unknown_dtypes:
warnings.warn(
"The function has one or more missing input data types."
f" BigQuery DataFrames will assume default data type {bigframes.dtypes.DEFAULT_DTYPE} for them.",
category=bigframes.exceptions.UnknownDataTypeWarning,
)
func.input_dtypes = tuple(function_input_dtypes) # type: ignore

func.output_dtype = bigframes.core.compile.ibis_types.ibis_dtype_to_bigframes_dtype( # type: ignore
ibis_signature.output_type
)
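
A hedged sketch of when the new UnknownDataTypeWarning fires (the routine name is hypothetical; read_gbq_function goes through this code path):

```python
import warnings

import bigframes.exceptions
import bigframes.pandas as bpd

# If the BigQuery routine's signature omits an input data type, the default
# dtype is assumed for it and an UnknownDataTypeWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    func = bpd.read_gbq_function("my_project.my_dataset.my_routine")

assumed_defaults = [
    w for w in caught
    if issubclass(w.category, bigframes.exceptions.UnknownDataTypeWarning)
]
```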
13 changes: 13 additions & 0 deletions bigframes/operations/__init__.py
@@ -659,6 +659,19 @@ def output_type(self, *input_types):
raise AttributeError("output_dtype not defined")


@dataclasses.dataclass(frozen=True)
class NaryRemoteFunctionOp(NaryOp):
name: typing.ClassVar[str] = "nary_remote_function"
func: typing.Callable

def output_type(self, *input_types):
# This property should be set to a valid Dtype by the @remote_function decorator or read_gbq_function method
if hasattr(self.func, "output_dtype"):
return self.func.output_dtype
else:
raise AttributeError("output_dtype not defined")
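
A minimal sketch of this contract (the stand-in function is hypothetical; in practice the remote_function decorator or read_gbq_function attaches output_dtype):

```python
import bigframes.dtypes
import bigframes.operations as ops

def fake_udf(x, y):  # hypothetical stand-in for a deployed remote function
    ...

fake_udf.output_dtype = bigframes.dtypes.FLOAT_DTYPE

# The op simply reads the dtype off the wrapped callable.
op = ops.NaryRemoteFunctionOp(func=fake_udf)
assert op.output_type() == bigframes.dtypes.FLOAT_DTYPE
```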


add_op = AddOp()
sub_op = SubOp()
mul_op = create_binary_op(name="mul", type_signature=op_typing.BINARY_NUMERIC)
11 changes: 7 additions & 4 deletions bigframes/series.py
@@ -1442,9 +1442,6 @@ def apply(
) -> Series:
# TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
# is actually a ternary op
# Reproject as workaround to applying filter too late. This forces the filter
# to be applied before passing data to remote function, protecting from bad
# inputs causing errors.

if by_row not in ["compat", False]:
raise ValueError("Param by_row must be one of 'compat' or False")
@@ -1474,7 +1471,10 @@
ex.message += f"\n{_remote_function_recommendation_message}"
raise

# We are working with remote function at this point
# We are working with remote function at this point.
# Reproject as workaround to applying filter too late. This forces the
# filter to be applied before passing data to remote function,
# protecting from bad inputs causing errors.
reprojected_series = Series(self._block._force_reproject())
result_series = reprojected_series._apply_unary_op(
ops.RemoteFunctionOp(func=func, apply_on_null=True)
@@ -1507,6 +1507,9 @@ def combine(
ex.message += f"\n{_remote_function_recommendation_message}"
raise

# Reproject as workaround to applying filter too late. This forces the
# filter to be applied before passing data to remote function,
# protecting from bad inputs causing errors.
reprojected_series = Series(self._block._force_reproject())
result_series = reprojected_series._apply_binary_op(
other, ops.BinaryRemoteFunctionOp(func=func)
4 changes: 2 additions & 2 deletions bigframes/session/__init__.py
@@ -1661,8 +1661,8 @@ def remote_function(
reuse (bool, Optional):
Reuse the remote function if already exists.
`True` by default, which will result in reusing an existing remote
function and corresponding cloud function (if any) that was
previously created for the same udf.
function and corresponding cloud function that was previously
created (if any) for the same udf.
Please note that for an unnamed (i.e. created without an explicit
`name` argument) remote function, the BigQuery DataFrames
session id is attached in the cloud artifacts names. So for the