Commit f6bdc4a
feat: Support axis=1 in df.apply for scalar outputs (#629)
* feat: Support `axis=1` in `df.apply` for scalar outputs
* avoid mixing other changes in the input_types param
* use guid instead of hard-coded column name
* check_exact=False to avoid failing system_prerelease
* handle index in remote function, add large system tests
* make the test case more robust
* handle non-string column names, add unsupported dtype tests
* fix import
* use `_cached` in df.apply to catch any rf execution errors early
* add test for row aggregates
* add row dtype information, also test
* preserve the order of input in the output
* absorb to_numpy() disparity in prerelease tests
* add tests for column multiindex and non-remote function
* add preview note for row processing
* add warning for input_types="row" and axis=1
* introduce early check on the supported dtypes
* adjust test after early dtype handling
* address review comments
* use NameError for column name parsing issue, address test coverage failure
* address nan return handling in the gcf code
* handle (nan, inf, -inf)
* replace "row" by bpd.Series for input types
* make the bq parity assert more readable
* fix the series name before assert
* fix docstring for args
* move more low-level string logic into the sql module
* raise explicit error when a column name cannot be supported
* keep literal_eval check on the serialization side to match deserialization
1 parent f2ed29c commit f6bdc4a

File tree: 9 files changed, +792 −52 lines
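The commit message above boils down to the following end-to-end usage. This is a hypothetical sketch, not code from this commit: it assumes a configured BigQuery DataFrames session with permission to deploy a remote function, and that `bpd.Series` is passed as the remote function input type (per the "replace \"row\" by bpd.Series for input types" bullet); the function and column names are made up.

import bigframes.pandas as bpd

df = bpd.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})

# Hypothetical remote function: declaring bpd.Series as the input type marks
# the function as row-processing, and float as the scalar output type.
@bpd.remote_function(bpd.Series, float, reuse=False)
def row_mean(row):
    return (row["a"] + row["b"]) / 2

# New in this commit: axis=1 serializes each row, runs the remote function on
# it, and returns a Series of scalars (a PreviewWarning is emitted).
result = df.apply(row_mean, axis=1)
print(result.to_pandas())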

bigframes/core/blocks.py (+94 −7)
@@ -21,11 +21,13 @@

 from __future__ import annotations

+import ast
 import dataclasses
 import functools
 import itertools
 import os
 import random
+import textwrap
 import typing
 from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
 import warnings
@@ -44,8 +46,8 @@
 import bigframes.core.join_def as join_defs
 import bigframes.core.ordering as ordering
 import bigframes.core.schema as bf_schema
+import bigframes.core.sql as sql
 import bigframes.core.tree_properties as tree_properties
-import bigframes.core.utils
 import bigframes.core.utils as utils
 import bigframes.core.window_spec as window_specs
 import bigframes.dtypes
@@ -1437,9 +1439,7 @@ def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:
         )

     def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
-        axis_number = bigframes.core.utils.get_axis_number(
-            "rows" if (axis is None) else axis
-        )
+        axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
         if axis_number == 0:
             expr = self._expr
             for index_col in self._index_columns:
@@ -1460,9 +1460,7 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
         return self.rename(columns=lambda label: f"{prefix}{label}")

     def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
-        axis_number = bigframes.core.utils.get_axis_number(
-            "rows" if (axis is None) else axis
-        )
+        axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
         if axis_number == 0:
             expr = self._expr
             for index_col in self._index_columns:
@@ -2072,6 +2070,95 @@ def _is_monotonic(
         self._stats_cache[column_name].update({op_name: result})
         return result

+    def _get_rows_as_json_values(self) -> Block:
+        # We want to preserve any ordering currently present before turning to
+        # direct SQL manipulation. We will restore the ordering when we rebuild
+        # expression.
+        # TODO(shobs): Replace direct SQL manipulation by structured expression
+        # manipulation
+        ordering_column_name = guid.generate_guid()
+        expr = self.session._cache_with_offsets(self.expr)
+        expr = expr.promote_offsets(ordering_column_name)
+        expr_sql = self.session._to_sql(expr)
+
+        # Names of the columns to serialize for the row.
+        # We will use the repr-eval pattern to serialize a value here and
+        # deserialize in the cloud function. Let's make sure that would work.
+        column_names = []
+        for col in list(self.index_columns) + [col for col in self.column_labels]:
+            serialized_column_name = repr(col)
+            try:
+                ast.literal_eval(serialized_column_name)
+            except Exception:
+                raise NameError(
+                    f"Column name type '{type(col).__name__}' is not supported for row serialization."
+                    " Please consider using a name for which literal_eval(repr(name)) works."
+                )
+
+            column_names.append(serialized_column_name)
+        column_names_csv = sql.csv(column_names, quoted=True)
+
+        # index columns count
+        index_columns_count = len(self.index_columns)
+
+        # column references to form the array of values for the row
+        column_references_csv = sql.csv(
+            [sql.cast_as_string(col) for col in self.expr.column_ids]
+        )
+
+        # types of the columns to serialize for the row
+        column_types = list(self.index.dtypes) + list(self.dtypes)
+        column_types_csv = sql.csv([str(typ) for typ in column_types], quoted=True)
+
+        # row dtype to use for deserializing the row as pandas series
+        pandas_row_dtype = bigframes.dtypes.lcd_type(*column_types)
+        if pandas_row_dtype is None:
+            pandas_row_dtype = "object"
+        pandas_row_dtype = sql.quote(str(pandas_row_dtype))
+
+        # create a json column representing row through SQL manipulation
+        row_json_column_name = guid.generate_guid()
+        select_columns = (
+            [ordering_column_name] + list(self.index_columns) + [row_json_column_name]
+        )
+        select_columns_csv = sql.csv(
+            [sql.column_reference(col) for col in select_columns]
+        )
+        json_sql = f"""\
+With T0 AS (
+{textwrap.indent(expr_sql, "    ")}
+),
+T1 AS (
+    SELECT *,
+           JSON_OBJECT(
+               "names", [{column_names_csv}],
+               "types", [{column_types_csv}],
+               "values", [{column_references_csv}],
+               "indexlength", {index_columns_count},
+               "dtype", {pandas_row_dtype}
+           ) AS {row_json_column_name} FROM T0
+)
+SELECT {select_columns_csv} FROM T1
+"""
+        ibis_table = self.session.ibis_client.sql(json_sql)
+        order_for_ibis_table = ordering.ExpressionOrdering.from_offset_col(
+            ordering_column_name
+        )
+        expr = core.ArrayValue.from_ibis(
+            self.session,
+            ibis_table,
+            [ibis_table[col] for col in select_columns if col != ordering_column_name],
+            hidden_ordering_columns=[ibis_table[ordering_column_name]],
+            ordering=order_for_ibis_table,
+        )
+        block = Block(
+            expr,
+            index_columns=self.index_columns,
+            column_labels=[row_json_column_name],
+            index_labels=self._index_labels,
+        )
+        return block
+

 class BlockIndexProperties:
     """Accessor for the index-related block properties."""

bigframes/core/sql.py (+59, new file)
@@ -0,0 +1,59 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility functions for SQL construction.
+"""
+
+from typing import Iterable
+
+
+def quote(value: str):
+    """Return quoted input string."""
+
+    # Let's use repr which also escapes any special characters
+    #
+    # >>> for val in [
+    # ...     "123",
+    # ...     "str with no special chars",
+    # ...     "str with special chars.,'\"/\\"
+    # ... ]:
+    # ...     print(f"{val} -> {repr(val)}")
+    # ...
+    # 123 -> '123'
+    # str with no special chars -> 'str with no special chars'
+    # str with special chars.,'"/\ -> 'str with special chars.,\'"/\\'
+
+    return repr(value)
+
+
+def column_reference(column_name: str):
+    """Return a string representing column reference in a SQL."""
+
+    return f"`{column_name}`"
+
+
+def cast_as_string(column_name: str):
+    """Return a string representing string casting of a column."""
+
+    return f"CAST({column_reference(column_name)} AS STRING)"
+
+
+def csv(values: Iterable[str], quoted=False):
+    """Return a string of comma separated values."""
+
+    if quoted:
+        values = [quote(val) for val in values]
+
+    return ", ".join(values)

bigframes/dataframe.py (+55 −1)
@@ -34,6 +34,7 @@
     Tuple,
     Union,
 )
+import warnings

 import bigframes_vendored.pandas.core.frame as vendored_pandas_frame
 import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
@@ -61,6 +62,7 @@
 import bigframes.core.window
 import bigframes.core.window_spec as window_spec
 import bigframes.dtypes
+import bigframes.exceptions
 import bigframes.formatting_helpers as formatter
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
@@ -3308,7 +3310,59 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
             ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None))
         )

-    def apply(self, func, *, args: typing.Tuple = (), **kwargs):
+    def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
+        if utils.get_axis_number(axis) == 1:
+            warnings.warn(
+                "axis=1 scenario is in preview.",
+                category=bigframes.exceptions.PreviewWarning,
+            )
+
+            # Early check whether the dataframe dtypes are currently supported
+            # in the remote function
+            # NOTE: Keep in sync with the value converters used in the gcf code
+            # generated in generate_cloud_function_main_code in remote_function.py
+            remote_function_supported_dtypes = (
+                bigframes.dtypes.INT_DTYPE,
+                bigframes.dtypes.FLOAT_DTYPE,
+                bigframes.dtypes.BOOL_DTYPE,
+                bigframes.dtypes.STRING_DTYPE,
+            )
+            supported_dtypes_types = tuple(
+                type(dtype) for dtype in remote_function_supported_dtypes
+            )
+            supported_dtypes_hints = tuple(
+                str(dtype) for dtype in remote_function_supported_dtypes
+            )
+
+            for dtype in self.dtypes:
+                if not isinstance(dtype, supported_dtypes_types):
+                    raise NotImplementedError(
+                        f"DataFrame has a column of dtype '{dtype}' which is not supported with axis=1."
+                        f" Supported dtypes are {supported_dtypes_hints}."
+                    )
+
+            # Check if the function is a remote function
+            if not hasattr(func, "bigframes_remote_function"):
+                raise ValueError("For axis=1 a remote function must be used.")
+
+            # Serialize the rows as json values
+            block = self._get_block()
+            rows_as_json_series = bigframes.series.Series(
+                block._get_rows_as_json_values()
+            )
+
+            # Apply the function
+            result_series = rows_as_json_series._apply_unary_op(
+                ops.RemoteFunctionOp(func=func, apply_on_null=True)
+            )
+            result_series.name = None
+
+            # Return Series with materialized result so that any error in the remote
+            # function is caught early
+            materialized_series = result_series.cache()
+            return materialized_series
+
+        # Per-column apply
         results = {name: func(col, *args, **kwargs) for name, col in self.items()}
         if all(
             [
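A short sketch of the guard rails this change adds to `DataFrame.apply`. The frame and column names are hypothetical and a configured BigQuery DataFrames session is assumed; the warning class and error message come from the diff above.

import warnings

import bigframes.exceptions
import bigframes.pandas as bpd

df = bpd.DataFrame({"x": [1, 2, 3], "y": [10.0, 20.0, 30.0]})

# axis=1 is in preview; the new PreviewWarning can be silenced explicitly.
warnings.simplefilter("ignore", bigframes.exceptions.PreviewWarning)

# Columns must be Int64, Float64, boolean, or string, and the function must be
# a remote function: rows are serialized to JSON and evaluated by a BigQuery
# remote function, not locally, so a plain callable fails fast.
try:
    df.apply(lambda row: row["x"] + row["y"], axis=1)
except ValueError as exc:
    print(exc)  # For axis=1 a remote function must be used.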

bigframes/exceptions.py (+4)
@@ -33,3 +33,7 @@ class CleanupFailedWarning(Warning):

 class DefaultIndexWarning(Warning):
     """Default index may cause unexpected costs."""
+
+
+class PreviewWarning(Warning):
+    """The feature is in preview."""
