Commit 4b8fc15

fix: include internally required packages in remote_function hash (#799)
* fix: include internally required packages in `remote_function` id computation
* refactor to keep the tests passing
1 parent: 2e692e9

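For context, `_get_hash` (first hunk below) derives the cloud function id from the pickled user function plus any package requirements; requirements that never reach the hash therefore cannot invalidate a previously deployed function. A minimal sketch of that hashing scheme, assuming the sorted requirement strings are appended to the pickled function bytes before digesting (the exact serialization in bigframes may differ):

import hashlib

import cloudpickle


def _get_hash_sketch(def_, package_requirements=None):
    # Sketch only, not the bigframes implementation: serialize the function,
    # then mix each requirement string into the digest so any change in the
    # pinned packages changes the resulting id.
    def_repr = cloudpickle.dumps(def_)
    for req in sorted(package_requirements or []):
        def_repr += req.encode()
    return hashlib.md5(def_repr).hexdigest()

Before this commit, the internally injected packages (cloudpickle always, plus numpy/pandas/pyarrow for row processors) were added only while writing requirements.txt, after the name had already been computed, so they never affected the id.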
3 files changed: +140 −20 lines

bigframes/functions/remote_function.py (+36 −19)
@@ -102,6 +102,24 @@ def _get_hash(def_, package_requirements=None):
     return hashlib.md5(def_repr).hexdigest()
 
 
+def _get_updated_package_requirements(package_requirements, is_row_processor):
+    requirements = [f"cloudpickle=={cloudpickle.__version__}"]
+    if is_row_processor:
+        # bigframes remote function will send an entire row of data as json,
+        # which would be converted to a pandas series and processed
+        # Ensure numpy versions match to avoid unpickling problems. See
+        # internal issue b/347934471.
+        requirements.append(f"numpy=={numpy.__version__}")
+        requirements.append(f"pandas=={pandas.__version__}")
+        requirements.append(f"pyarrow=={pyarrow.__version__}")
+
+    if package_requirements:
+        requirements.extend(package_requirements)
+
+    requirements = sorted(requirements)
+    return requirements
+
+
 def routine_ref_to_string_for_query(routine_ref: bigquery.RoutineReference) -> str:
     return f"`{routine_ref.project}.{routine_ref.dataset_id}`.{routine_ref.routine_id}"
 
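To illustrate what the new helper returns (the version pins below are hypothetical; the real values come from whatever is installed in the deploying environment):

# Hypothetical pins, for illustration only.
reqs = _get_updated_package_requirements(["pytest"], is_row_processor=True)
# ['cloudpickle==2.2.1', 'numpy==1.26.4', 'pandas==2.1.4',
#  'pyarrow==14.0.2', 'pytest']

The sort makes the list deterministic, which in turn keeps the hash computed from it stable across runs for the same set of requirements.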
@@ -112,13 +130,22 @@ class IbisSignature(NamedTuple):
     output_type: IbisDataType
 
 
-def get_cloud_function_name(def_, uniq_suffix=None, package_requirements=None):
+def get_cloud_function_name(
+    def_, uniq_suffix=None, package_requirements=None, is_row_processor=False
+):
     "Get a name for the cloud function for the given user defined function."
+
+    # Augment user package requirements with any internal package
+    # requirements
+    package_requirements = _get_updated_package_requirements(
+        package_requirements, is_row_processor
+    )
+
     cf_name = _get_hash(def_, package_requirements)
     cf_name = f"bigframes-{cf_name}"  # for identification
     if uniq_suffix:
         cf_name = f"{cf_name}-{uniq_suffix}"
-    return cf_name
+    return cf_name, package_requirements
 
 
 def get_remote_function_name(def_, uniq_suffix=None, package_requirements=None):
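With this change the caller receives both the name and the augmented requirements, so the list that was hashed is exactly the list later written to requirements.txt. A hedged usage sketch (the lambda, suffix, and hash value are illustrative):

name, reqs = get_cloud_function_name(lambda x: x + 1, uniq_suffix="abc123")
# name looks like "bigframes-<md5-hex>-abc123"; reqs now includes the
# cloudpickle pin even though the caller passed no requirements.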
@@ -277,21 +304,10 @@ def generate_cloud_function_code(
     """
 
     # requirements.txt
-    requirements = ["cloudpickle >= 2.1.0"]
-    if is_row_processor:
-        # bigframes remote function will send an entire row of data as json,
-        # which would be converted to a pandas series and processed
-        # Ensure numpy versions match to avoid unpickling problems. See
-        # internal issue b/347934471.
-        requirements.append(f"numpy=={numpy.__version__}")
-        requirements.append(f"pandas=={pandas.__version__}")
-        requirements.append(f"pyarrow=={pyarrow.__version__}")
     if package_requirements:
-        requirements.extend(package_requirements)
-    requirements = sorted(requirements)
-    requirements_txt = os.path.join(directory, "requirements.txt")
-    with open(requirements_txt, "w") as f:
-        f.write("\n".join(requirements))
+        requirements_txt = os.path.join(directory, "requirements.txt")
+        with open(requirements_txt, "w") as f:
+            f.write("\n".join(package_requirements))
 
     # main.py
     entry_point = bigframes.functions.remote_function_template.generate_cloud_function_main_code(
@@ -469,9 +485,10 @@ def provision_bq_remote_function(
         )
 
         # Derive the name of the cloud function underlying the intended BQ
-        # remote function
-        cloud_function_name = get_cloud_function_name(
-            def_, uniq_suffix, package_requirements
+        # remote function, also collect updated package requirements as
+        # determined in the name resolution
+        cloud_function_name, package_requirements = get_cloud_function_name(
+            def_, uniq_suffix, package_requirements, is_row_processor
         )
         cf_endpoint = self.get_cloud_function_endpoint(cloud_function_name)
 
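The provisioning path above is where the fix pays off: because `is_row_processor` now flows into the name computation, the same user function maps to different cloud function names for the scalar and row-processor cases, so one can no longer silently reuse a deployment built for the other. A small illustration (with `add_one` as a hypothetical user function):

def add_one(x):
    return x + 1

name_scalar, _ = get_cloud_function_name(add_one)
name_row, _ = get_cloud_function_name(add_one, is_row_processor=True)
# The pinned numpy/pandas/pyarrow requirements change the hash input.
assert name_scalar != name_row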
tests/system/large/test_remote_function.py (+1 −1)
@@ -590,7 +590,7 @@ def add_one(x):
     add_one_uniq, add_one_uniq_dir = make_uniq_udf(add_one)
 
     # Expected cloud function name for the unique udf
-    add_one_uniq_cf_name = get_cloud_function_name(add_one_uniq)
+    add_one_uniq_cf_name, _ = get_cloud_function_name(add_one_uniq)
 
     # There should be no cloud function yet for the unique udf
     cloud_functions = list(

tests/system/small/test_remote_function.py (+103 −0)
@@ -742,6 +742,109 @@ def test_read_gbq_function_enforces_explicit_types(
     )
 
 
+@pytest.mark.flaky(retries=2, delay=120)
+def test_df_apply_axis_1(session, scalars_dfs):
+    columns = [
+        "bool_col",
+        "int64_col",
+        "int64_too",
+        "float64_col",
+        "string_col",
+        "bytes_col",
+    ]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    def add_ints(row):
+        return row["int64_col"] + row["int64_too"]
+
+    with pytest.warns(
+        bigframes.exceptions.PreviewWarning,
+        match="input_types=Series is in preview.",
+    ):
+        add_ints_remote = session.remote_function(
+            bigframes.series.Series,
+            int,
+        )(add_ints)
+
+    with pytest.warns(
+        bigframes.exceptions.PreviewWarning, match="axis=1 scenario is in preview."
+    ):
+        bf_result = scalars_df[columns].apply(add_ints_remote, axis=1).to_pandas()
+
+    pd_result = scalars_pandas_df[columns].apply(add_ints, axis=1)
+
+    # bf_result.dtype is 'Int64' while pd_result.dtype is 'object', ignore this
+    # mismatch by using check_dtype=False.
+    #
+    # bf_result.to_numpy() produces an array of numpy.float64's
+    # (in system_prerelease tests), while pd_result.to_numpy() produces an
+    # array of ints, ignore this mismatch by using check_exact=False.
+    pd.testing.assert_series_equal(
+        pd_result, bf_result, check_dtype=False, check_exact=False
+    )
+
+
+@pytest.mark.flaky(retries=2, delay=120)
+def test_df_apply_axis_1_ordering(session, scalars_dfs):
+    columns = ["bool_col", "int64_col", "int64_too", "float64_col", "string_col"]
+    ordering_columns = ["bool_col", "int64_col"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    def add_ints(row):
+        return row["int64_col"] + row["int64_too"]
+
+    add_ints_remote = session.remote_function(bigframes.series.Series, int)(add_ints)
+
+    bf_result = (
+        scalars_df[columns]
+        .sort_values(ordering_columns)
+        .apply(add_ints_remote, axis=1)
+        .to_pandas()
+    )
+    pd_result = (
+        scalars_pandas_df[columns].sort_values(ordering_columns).apply(add_ints, axis=1)
+    )
+
+    # bf_result.dtype is 'Int64' while pd_result.dtype is 'object', ignore this
+    # mismatch by using check_dtype=False.
+    #
+    # bf_result.to_numpy() produces an array of numpy.float64's
+    # (in system_prerelease tests), while pd_result.to_numpy() produces an
+    # array of ints, ignore this mismatch by using check_exact=False.
+    pd.testing.assert_series_equal(
+        pd_result, bf_result, check_dtype=False, check_exact=False
+    )
+
+
+@pytest.mark.flaky(retries=2, delay=120)
+def test_df_apply_axis_1_multiindex(session):
+    pd_df = pd.DataFrame(
+        {"x": [1, 2, 3], "y": [1.5, 3.75, 5], "z": ["pq", "rs", "tu"]},
+        index=pd.MultiIndex.from_tuples([("a", 100), ("a", 200), ("b", 300)]),
+    )
+    bf_df = session.read_pandas(pd_df)
+
+    def add_numbers(row):
+        return row["x"] + row["y"]
+
+    add_numbers_remote = session.remote_function(bigframes.series.Series, float)(
+        add_numbers
+    )
+
+    bf_result = bf_df.apply(add_numbers_remote, axis=1).to_pandas()
+    pd_result = pd_df.apply(add_numbers, axis=1)
+
+    # bf_result.dtype is 'Float64' while pd_result.dtype is 'float64', ignore this
+    # mismatch by using check_dtype=False.
+    #
+    # bf_result.index[0].dtype is 'string[pyarrow]' while
+    # pd_result.index[0].dtype is 'object', ignore this mismatch by using
+    # check_index_type=False.
+    pd.testing.assert_series_equal(
+        pd_result, bf_result, check_dtype=False, check_index_type=False
+    )
+
+
 def test_df_apply_axis_1_unsupported_callable(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     columns = ["bool_col", "int64_col", "int64_too", "float64_col", "string_col"]