Skip to content

feat: support type annotations to supply input and output types to @remote_function decorator #717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
import sys
import tempfile
import textwrap
from typing import List, NamedTuple, Optional, Sequence, TYPE_CHECKING, Union
import typing
from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Union
import warnings

import ibis
import pandas
import requests

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
from bigframes.session import Session

import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
Expand Down Expand Up @@ -735,8 +736,8 @@ def get_routine_reference(
# which has moved as @js to the ibis package
# https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/udf/__init__.py
def remote_function(
input_types: Union[type, Sequence[type]],
output_type: type,
input_types: Union[None, type, Sequence[type]] = None,
output_type: Optional[type] = None,
session: Optional[Session] = None,
bigquery_client: Optional[bigquery.Client] = None,
bigquery_connection_client: Optional[
Expand Down Expand Up @@ -800,11 +801,11 @@ def remote_function(
`$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`.

Args:
input_types (type or sequence(type)):
input_types (None, type, or sequence(type)):
For scalar user defined function it should be the input type or
sequence of input types. For row processing user defined function,
type `Series` should be specified.
output_type (type):
output_type (Optional[type]):
Data type of the output in the user defined function.
session (bigframes.Session, Optional):
BigQuery DataFrames session to use for getting default project,
Expand Down Expand Up @@ -907,26 +908,9 @@ def remote_function(
service(s) that are on a VPC network. See for more details
https://cloud.google.com/functions/docs/networking/connecting-vpc.
"""
is_row_processor = False

import bigframes.series

if input_types == bigframes.series.Series:
warnings.warn(
"input_types=Series scenario is in preview.",
stacklevel=1,
category=bigframes.exceptions.PreviewWarning,
)

# we will model the row as a json serialized string containing the data
# and the metadata representing the row
input_types = [str]
is_row_processor = True
elif isinstance(input_types, type):
input_types = [input_types]

# Some defaults may be used from the session if not provided otherwise
import bigframes.pandas as bpd
import bigframes.series

session = session or bpd.get_global_session()

Expand Down Expand Up @@ -1019,10 +1003,61 @@ def remote_function(
bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
nonlocal input_types, output_type

if not callable(f):
raise TypeError("f must be callable, got {}".format(f))

signature = inspect.signature(f)
if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
# corresponding type objects. Need Python 3.10 for eval_str parameter.
# https://docs.python.org/3/library/inspect.html#inspect.signature
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
else:
signature_kwargs = {}

signature = inspect.signature(
f,
**signature_kwargs,
)

# Try to get input types via type annotations.
if input_types is None:
input_types = []
for parameter in signature.parameters.values():
if (param_type := parameter.annotation) is inspect.Signature.empty:
raise ValueError(
"'input_types' was not set and parameter "
f"'{parameter.name}' is missing a type annotation. "
"Types are required to use @remote_function."
)
input_types.append(param_type)

if output_type is None:
if (output_type := signature.return_annotation) is inspect.Signature.empty:
raise ValueError(
"'output_type' was not set and function is missing a "
"return type annotation. Types are required to use "
"@remote_function."
)

# The function will actually be receiving a pandas Series, but allow both
# BigQuery DataFrames and pandas object types for compatibility.
is_row_processor = False
if input_types == bigframes.series.Series or input_types == pandas.Series:
warnings.warn(
"input_types=Series scenario is in preview.",
stacklevel=1,
category=bigframes.exceptions.PreviewWarning,
)

# we will model the row as a json serialized string containing the data
# and the metadata representing the row
input_types = [str]
is_row_processor = True
elif isinstance(input_types, type):
input_types = [input_types]

# TODO(b/340898611): fix type error
ibis_signature = ibis_signature_from_python_signature(
signature, input_types, output_type # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def create_bigquery_session(
session_id: str = "abcxyz",
table_schema: Sequence[google.cloud.bigquery.SchemaField] = TEST_SCHEMA,
anonymous_dataset: Optional[google.cloud.bigquery.DatasetReference] = None,
location: str = "test-region",
) -> bigframes.Session:
credentials = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
Expand All @@ -53,11 +54,12 @@ def create_bigquery_session(
if bqclient is None:
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
bqclient.project = "test-project"
bqclient.location = location

# Mock the location.
table = mock.create_autospec(google.cloud.bigquery.Table, instance=True)
table._properties = {}
type(table).location = mock.PropertyMock(return_value="test-region")
type(table).location = mock.PropertyMock(return_value=location)
type(table).schema = mock.PropertyMock(return_value=table_schema)
type(table).reference = mock.PropertyMock(
return_value=anonymous_dataset.table("test_table")
Expand Down Expand Up @@ -93,9 +95,7 @@ def query_mock(query, *args, **kwargs):
type(clients_provider).bqclient = mock.PropertyMock(return_value=bqclient)
clients_provider._credentials = credentials

bqoptions = bigframes.BigQueryOptions(
credentials=credentials, location="test-region"
)
bqoptions = bigframes.BigQueryOptions(credentials=credentials, location=location)
session = bigframes.Session(context=bqoptions, clients_provider=clients_provider)
return session

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def all_session_methods():
[(method_name,) for method_name in all_session_methods()],
)
def test_method_matches_session(method_name: str):
if sys.version_info <= (3, 10):
if sys.version_info < (3, 10):
pytest.skip(
"Need Python 3.10 to reconcile deferred annotations."
) # pragma: no cover
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
from ibis.expr import datatypes as ibis_types
import pytest

import bigframes.dtypes
import bigframes.functions.remote_function
from tests.unit import resources


def test_supported_types_correspond():
Expand All @@ -29,3 +32,39 @@ def test_supported_types_correspond():
}

assert ibis_types_from_python == ibis_types_from_bigquery


def test_missing_input_types():
session = resources.create_bigquery_session()
remote_function_decorator = bigframes.functions.remote_function.remote_function(
session=session
)

def function_without_parameter_annotations(myparam) -> str:
return str(myparam)

assert function_without_parameter_annotations(42) == "42"

with pytest.raises(
ValueError,
match="'input_types' was not set .* 'myparam' is missing a type annotation",
):
remote_function_decorator(function_without_parameter_annotations)


def test_missing_output_type():
session = resources.create_bigquery_session()
remote_function_decorator = bigframes.functions.remote_function.remote_function(
session=session
)

def function_without_return_annotation(myparam: int):
return str(myparam)

assert function_without_return_annotation(42) == "42"

with pytest.raises(
ValueError,
match="'output_type' was not set .* missing a return type annotation",
):
remote_function_decorator(function_without_return_annotation)
16 changes: 10 additions & 6 deletions third_party/bigframes_vendored/pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3892,8 +3892,8 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
to potentially reuse a previously deployed ``remote_function`` from
the same user defined function.

>>> @bpd.remote_function(int, float, reuse=False)
... def minutes_to_hours(x):
>>> @bpd.remote_function(reuse=False)
... def minutes_to_hours(x: int) -> float:
... return x/60

>>> df_minutes = bpd.DataFrame(
Expand Down Expand Up @@ -4214,6 +4214,7 @@ def apply(self, func, *, axis=0, args=(), **kwargs):
**Examples:**

>>> import bigframes.pandas as bpd
>>> import pandas as pd
>>> bpd.options.display.progress_bar = None

>>> df = bpd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
Expand All @@ -4235,16 +4236,19 @@ def apply(self, func, *, axis=0, args=(), **kwargs):
[2 rows x 2 columns]

You could apply a user defined function to every row of the DataFrame by
creating a remote function out of it, and using it with `axis=1`.
creating a remote function out of it, and using it with `axis=1`. Within
the function, each row is passed as a ``pandas.Series``. It is recommended
to select only the necessary columns before calling `apply()`. Note: This
feature is currently in **preview**.

>>> @bpd.remote_function(bpd.Series, int, reuse=False)
... def foo(row):
>>> @bpd.remote_function(reuse=False)
... def foo(row: pd.Series) -> int:
... result = 1
... result += row["col1"]
... result += row["col2"]*row["col2"]
... return result

>>> df.apply(foo, axis=1)
>>> df[["col1", "col2"]].apply(foo, axis=1)
0 11
1 19
dtype: Int64
Expand Down
16 changes: 7 additions & 9 deletions third_party/bigframes_vendored/pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,8 +1183,8 @@ def apply(
to potentially reuse a previously deployed `remote_function` from
the same user defined function.

>>> @bpd.remote_function(int, float, reuse=False)
... def minutes_to_hours(x):
>>> @bpd.remote_function(reuse=False)
... def minutes_to_hours(x: int) -> float:
... return x/60

>>> minutes = bpd.Series([0, 30, 60, 90, 120])
Expand All @@ -1210,12 +1210,10 @@ def apply(
`packages` param.

>>> @bpd.remote_function(
... str,
... str,
... reuse=False,
... packages=["cryptography"],
... )
... def get_hash(input):
... def get_hash(input: str) -> str:
... from cryptography.fernet import Fernet
...
... # handle missing value
Expand Down Expand Up @@ -3452,8 +3450,8 @@ def mask(self, cond, other):
condition is evaluated based on a complicated business logic which cannot
be expressed in form of a Series.

>>> @bpd.remote_function(str, bool, reuse=False)
... def should_mask(name):
>>> @bpd.remote_function(reuse=False)
... def should_mask(name: str) -> bool:
... hash = 0
... for char_ in name:
... hash += ord(char_)
Expand Down Expand Up @@ -3971,8 +3969,8 @@ def map(

It also accepts a remote function:

>>> @bpd.remote_function(str, str)
... def my_mapper(val):
>>> @bpd.remote_function
... def my_mapper(val: str) -> str:
... vowels = ["a", "e", "i", "o", "u"]
... if val:
... return "".join([
Expand Down