Skip to content

Commit d850da6

Browse files
authored
feat: allow functions decorated with @bpd.remote_function to execute locally (#704)
* feat: allow functions decorated with `@bpd.remote_function` to execute locally * fix read_gbq_function * fix for rare case where re-deploy exact same function object
1 parent 4a12e3c commit d850da6

File tree

4 files changed

+94
-43
lines changed

4 files changed

+94
-43
lines changed

bigframes/core/compile/scalar_op_compiler.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -856,11 +856,12 @@ def to_timestamp_op_impl(x: ibis_types.Value, op: ops.ToTimestampOp):
856856

857857
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
858858
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
859-
if not hasattr(op.func, "bigframes_remote_function"):
859+
ibis_node = getattr(op.func, "ibis_node", None)
860+
if ibis_node is None:
860861
raise TypeError(
861862
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
862863
)
863-
x_transformed = op.func(x)
864+
x_transformed = ibis_node(x)
864865
if not op.apply_on_null:
865866
x_transformed = ibis.case().when(x.isnull(), x).else_(x_transformed).end()
866867
return x_transformed
@@ -1342,11 +1343,12 @@ def minimum_impl(
13421343
def binary_remote_function_op_impl(
13431344
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
13441345
):
1345-
if not hasattr(op.func, "bigframes_remote_function"):
1346+
ibis_node = getattr(op.func, "ibis_node", None)
1347+
if ibis_node is None:
13461348
raise TypeError(
13471349
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
13481350
)
1349-
x_transformed = op.func(x, y)
1351+
x_transformed = ibis_node(x, y)
13501352
return x_transformed
13511353

13521354

bigframes/functions/remote_function.py

+36-16
Original file line numberDiff line numberDiff line change
@@ -1013,11 +1013,11 @@ def remote_function(
10131013

10141014
bq_connection_manager = None if session is None else session.bqconnectionmanager
10151015

1016-
def wrapper(f):
1016+
def wrapper(func):
10171017
nonlocal input_types, output_type
10181018

1019-
if not callable(f):
1020-
raise TypeError("f must be callable, got {}".format(f))
1019+
if not callable(func):
1020+
raise TypeError("f must be callable, got {}".format(func))
10211021

10221022
if sys.version_info >= (3, 10):
10231023
# Add `eval_str = True` so that deferred annotations are turned into their
@@ -1028,7 +1028,7 @@ def wrapper(f):
10281028
signature_kwargs = {}
10291029

10301030
signature = inspect.signature(
1031-
f,
1031+
func,
10321032
**signature_kwargs,
10331033
)
10341034

@@ -1089,8 +1089,23 @@ def wrapper(f):
10891089
session=session, # type: ignore
10901090
)
10911091

1092+
# In the unlikely case where the user is trying to re-deploy the same
1093+
# function, cleanup the attributes we add below, first. This prevents
1094+
# the pickle from having dependencies that might not otherwise be
1095+
# present such as ibis or pandas.
1096+
def try_delattr(attr):
1097+
try:
1098+
delattr(func, attr)
1099+
except AttributeError:
1100+
pass
1101+
1102+
try_delattr("bigframes_cloud_function")
1103+
try_delattr("bigframes_remote_function")
1104+
try_delattr("output_dtype")
1105+
try_delattr("ibis_node")
1106+
10921107
rf_name, cf_name = remote_function_client.provision_bq_remote_function(
1093-
f,
1108+
func,
10941109
ibis_signature.input_types,
10951110
ibis_signature.output_type,
10961111
reuse,
@@ -1105,19 +1120,20 @@ def wrapper(f):
11051120

11061121
# TODO: Move ibis logic to compiler step
11071122
node = ibis.udf.scalar.builtin(
1108-
f,
1123+
func,
11091124
name=rf_name,
11101125
schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}",
11111126
signature=(ibis_signature.input_types, ibis_signature.output_type),
11121127
)
1113-
node.bigframes_cloud_function = (
1128+
func.bigframes_cloud_function = (
11141129
remote_function_client.get_cloud_function_fully_qualified_name(cf_name)
11151130
)
1116-
node.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
1117-
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
1131+
func.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
1132+
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
11181133
ibis_signature.output_type
11191134
)
1120-
return node
1135+
func.ibis_node = node
1136+
return func
11211137

11221138
return wrapper
11231139

@@ -1168,19 +1184,23 @@ def read_gbq_function(
11681184

11691185
# The name "args" conflicts with the Ibis operator, so we use
11701186
# non-standard names for the arguments here.
1171-
def node(*ignored_args, **ignored_kwargs):
1187+
def func(*ignored_args, **ignored_kwargs):
11721188
f"""Remote function {str(routine_ref)}."""
1189+
# TODO(swast): Construct an ibis client from bigquery_client and
1190+
# execute node via a query.
11731191

11741192
# TODO: Move ibis logic to compiler step
1175-
node.__name__ = routine_ref.routine_id
1193+
func.__name__ = routine_ref.routine_id
1194+
11761195
node = ibis.udf.scalar.builtin(
1177-
node,
1196+
func,
11781197
name=routine_ref.routine_id,
11791198
schema=f"{routine_ref.project}.{routine_ref.dataset_id}",
11801199
signature=(ibis_signature.input_types, ibis_signature.output_type),
11811200
)
1182-
node.bigframes_remote_function = str(routine_ref) # type: ignore
1183-
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
1201+
func.bigframes_remote_function = str(routine_ref) # type: ignore
1202+
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
11841203
ibis_signature.output_type
11851204
)
1186-
return node
1205+
func.ibis_node = node # type: ignore
1206+
return func

tests/system/large/test_remote_function.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def test_remote_function_stringify_with_ibis(
194194
def stringify(x):
195195
return f"I got {x}"
196196

197+
# Function should work locally.
198+
assert stringify(42) == "I got 42"
199+
197200
_, dataset_name, table_name = scalars_table_id.split(".")
198201
if not ibis_client.dataset:
199202
ibis_client.dataset = dataset_name
@@ -205,7 +208,7 @@ def stringify(x):
205208
pandas_df_orig = bigquery_client.query(sql).to_dataframe()
206209

207210
col = table[col_name]
208-
col_2x = stringify(col).name("int64_str_col")
211+
col_2x = stringify.ibis_node(col).name("int64_str_col")
209212
table = table.mutate([col_2x])
210213
sql = table.compile()
211214
pandas_df_new = bigquery_client.query(sql).to_dataframe()

tests/system/small/test_remote_function.py

+48-22
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def bq_cf_connection_location_project(bigquery_client) -> str:
6767

6868
@pytest.fixture(scope="module")
6969
def bq_cf_connection_location_project_mismatched() -> str:
70-
"""Pre-created BQ connection in the migframes-metrics project in US location,
70+
"""Pre-created BQ connection in the bigframes-metrics project in US location,
7171
in format PROJECT_ID.LOCATION.CONNECTION_NAME, used to invoke cloud function.
7272
7373
$ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn
@@ -108,11 +108,15 @@ def test_remote_function_direct_no_session_param(
108108
reuse=True,
109109
)
110110
def square(x):
111-
# This executes on a remote function, where coverage isn't tracked.
112-
return x * x # pragma: NO COVER
111+
return x * x
113112

114-
assert square.bigframes_remote_function
115-
assert square.bigframes_cloud_function
113+
# Function should still work normally.
114+
assert square(2) == 4
115+
116+
# Function should have extra metadata attached for remote execution.
117+
assert hasattr(square, "bigframes_remote_function")
118+
assert hasattr(square, "bigframes_cloud_function")
119+
assert hasattr(square, "ibis_node")
116120

117121
scalars_df, scalars_pandas_df = scalars_dfs
118122

@@ -161,8 +165,10 @@ def test_remote_function_direct_no_session_param_location_specified(
161165
reuse=True,
162166
)
163167
def square(x):
164-
# This executes on a remote function, where coverage isn't tracked.
165-
return x * x # pragma: NO COVER
168+
return x * x
169+
170+
# Function should still work normally.
171+
assert square(2) == 4
166172

167173
scalars_df, scalars_pandas_df = scalars_dfs
168174

@@ -197,7 +203,10 @@ def test_remote_function_direct_no_session_param_location_mismatched(
197203
dataset_id_permanent,
198204
bq_cf_connection_location_mismatched,
199205
):
200-
with pytest.raises(ValueError):
206+
with pytest.raises(
207+
ValueError,
208+
match=re.escape("The location does not match BigQuery connection location:"),
209+
):
201210

202211
@rf.remote_function(
203212
[int],
@@ -212,7 +221,8 @@ def test_remote_function_direct_no_session_param_location_mismatched(
212221
reuse=True,
213222
)
214223
def square(x):
215-
# This executes on a remote function, where coverage isn't tracked.
224+
# Not expected to reach this code, as the location of the
225+
# connection doesn't match the location of the dataset.
216226
return x * x # pragma: NO COVER
217227

218228

@@ -239,8 +249,10 @@ def test_remote_function_direct_no_session_param_location_project_specified(
239249
reuse=True,
240250
)
241251
def square(x):
242-
# This executes on a remote function, where coverage isn't tracked.
243-
return x * x # pragma: NO COVER
252+
return x * x
253+
254+
# Function should still work normally.
255+
assert square(2) == 4
244256

245257
scalars_df, scalars_pandas_df = scalars_dfs
246258

@@ -275,7 +287,12 @@ def test_remote_function_direct_no_session_param_project_mismatched(
275287
dataset_id_permanent,
276288
bq_cf_connection_location_project_mismatched,
277289
):
278-
with pytest.raises(ValueError):
290+
with pytest.raises(
291+
ValueError,
292+
match=re.escape(
293+
"The project_id does not match BigQuery connection gcp_project_id:"
294+
),
295+
):
279296

280297
@rf.remote_function(
281298
[int],
@@ -290,7 +307,8 @@ def test_remote_function_direct_no_session_param_project_mismatched(
290307
reuse=True,
291308
)
292309
def square(x):
293-
# This executes on a remote function, where coverage isn't tracked.
310+
# Not expected to reach this code, as the project of the
311+
# connection doesn't match the project of the dataset.
294312
return x * x # pragma: NO COVER
295313

296314

@@ -302,8 +320,10 @@ def test_remote_function_direct_session_param(session_with_bq_connection, scalar
302320
session=session_with_bq_connection,
303321
)
304322
def square(x):
305-
# This executes on a remote function, where coverage isn't tracked.
306-
return x * x # pragma: NO COVER
323+
return x * x
324+
325+
# Function should still work normally.
326+
assert square(2) == 4
307327

308328
scalars_df, scalars_pandas_df = scalars_dfs
309329

@@ -340,8 +360,10 @@ def test_remote_function_via_session_default(session_with_bq_connection, scalars
340360
# cloud function would be common and quickly reused.
341361
@session_with_bq_connection.remote_function([int], int)
342362
def square(x):
343-
# This executes on a remote function, where coverage isn't tracked.
344-
return x * x # pragma: NO COVER
363+
return x * x
364+
365+
# Function should still work normally.
366+
assert square(2) == 4
345367

346368
scalars_df, scalars_pandas_df = scalars_dfs
347369

@@ -380,8 +402,10 @@ def test_remote_function_via_session_with_overrides(
380402
reuse=True,
381403
)
382404
def square(x):
383-
# This executes on a remote function, where coverage isn't tracked.
384-
return x * x # pragma: NO COVER
405+
return x * x
406+
407+
# Function should still work normally.
408+
assert square(2) == 4
385409

386410
scalars_df, scalars_pandas_df = scalars_dfs
387411

@@ -508,7 +532,7 @@ def test_skip_bq_connection_check(dataset_id_permanent):
508532

509533
@session.remote_function([int], int, dataset=dataset_id_permanent)
510534
def add_one(x):
511-
# This executes on a remote function, where coverage isn't tracked.
535+
# Not expected to reach this code, as the connection doesn't exist.
512536
return x + 1 # pragma: NO COVER
513537

514538

@@ -546,8 +570,10 @@ def test_read_gbq_function_like_original(
546570
reuse=True,
547571
)
548572
def square1(x):
549-
# This executes on a remote function, where coverage isn't tracked.
550-
return x * x # pragma: NO COVER
573+
return x * x
574+
575+
# Function should still work normally.
576+
assert square1(2) == 4
551577

552578
square2 = rf.read_gbq_function(
553579
function_name=square1.bigframes_remote_function,

0 commit comments

Comments
 (0)