Skip to content

Commit 9959fc8

Browse files
fix: Fix issue with invalid sql generated by ml distance functions (#865)
1 parent 042db4b commit 9959fc8

File tree

8 files changed

+216
-154
lines changed

8 files changed

+216
-154
lines changed

bigframes/core/compile/scalar_op_compiler.py

+29
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,30 @@ def minimum_impl(
13801380
return ibis.case().when(upper.isnull() | (value > upper), upper).else_(value).end()
13811381

13821382

1383+
@scalar_op_compiler.register_binary_op(ops.cosine_distance_op)
def cosine_distance_impl(
    vector1: ibis_types.Value,
    vector2: ibis_types.Value,
):
    """Compile the cosine-distance op by delegating to ``vector_distance``.

    Both operands are forwarded unchanged, together with the literal
    distance type ``"COSINE"``.
    """
    distance_type = "COSINE"
    return vector_distance(vector1, vector2, distance_type)
1389+
1390+
1391+
@scalar_op_compiler.register_binary_op(ops.euclidean_distance_op)
def euclidean_distance_impl(
    vector1: ibis_types.Value,
    vector2: ibis_types.Value,
):
    """Compile the euclidean-distance op by delegating to ``vector_distance``.

    Both operands are forwarded unchanged, together with the literal
    distance type ``"EUCLIDEAN"``.
    """
    distance_type = "EUCLIDEAN"
    return vector_distance(vector1, vector2, distance_type)
1397+
1398+
1399+
@scalar_op_compiler.register_binary_op(ops.manhattan_distance_op)
def manhattan_distance_impl(
    vector1: ibis_types.Value,
    vector2: ibis_types.Value,
):
    """Compile the manhattan-distance op by delegating to ``vector_distance``.

    Both operands are forwarded unchanged, together with the literal
    distance type ``"MANHATTAN"``.
    """
    distance_type = "MANHATTAN"
    return vector_distance(vector1, vector2, distance_type)
1405+
1406+
13831407
@scalar_op_compiler.register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
13841408
def binary_remote_function_op_impl(
13851409
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
@@ -1501,3 +1525,8 @@ def json_set(
15011525
json_obj: ibis_dtypes.JSON, json_path: ibis_dtypes.str, json_value
15021526
) -> ibis_dtypes.JSON:
15031527
"""Produces a new SQL JSON value with the specified JSON data inserted or replaced."""
1528+
1529+
1530+
@ibis.udf.scalar.builtin(name="ML.DISTANCE")
def vector_distance(vector1, vector2, type: str) -> ibis_dtypes.Float64:
    """Distance between two vectors, declared as the builtin ``ML.DISTANCE``.

    ``type`` selects the distance measure and is one of "EUCLIDEAN",
    "MANHATTAN", or "COSINE". The function has no Python body beyond this
    docstring; it is only a typed declaration of the builtin.
    """

bigframes/ml/core.py

+55-76
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import annotations
1818

1919
import datetime
20-
from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
20+
from typing import Callable, cast, Iterable, Mapping, Optional, Union
2121
import uuid
2222

2323
from google.cloud import bigquery
@@ -35,11 +35,27 @@ def __init__(self, session: bigframes.Session):
3535
self._session = session
3636
self._base_sql_generator = ml_sql.BaseSqlGenerator()
3737

38-
def _apply_sql(
38+
39+
class BqmlModel(BaseBqml):
40+
"""Represents an existing BQML model in BigQuery.
41+
42+
Wraps the BQML API and SQL interface to expose the functionality needed for
43+
BigQuery DataFrames ML.
44+
"""
45+
46+
def __init__(self, session: bigframes.Session, model: bigquery.Model):
    """Store the session and model and build the model's SQL generator."""
    self._session, self._model = session, model
    # model_name is read only after _model has been assigned (presumably it
    # is derived from _model -- confirm), so keep this ordering.
    sql_generator = ml_sql.ModelManipulationSqlGenerator(self.model_name)
    self._model_manipulation_sql_generator = sql_generator
52+
53+
def _apply_ml_tvf(
    self,
    input_data: bpd.DataFrame,
    apply_sql_tvf: Callable[[str], str],
) -> bpd.DataFrame:
    """Wrap a dataframe in a SQL table-valued function, keeping its index.

    Shared helper behind predict, transform and the other TVF-backed
    methods of this class.

    Args:
        input_data:
            The dataframe whose SQL (including index columns) is fed to the
            table-valued function.
        apply_sql_tvf:
            Takes the input SQL table value and returns the SQL that applies
            the TVF to it. The resulting table value must include all input
            columns, with new columns appended to the end.

    Returns:
        The dataframe read back from the generated SQL, indexed by the input
        index columns and carrying the input's index names.
    """
    # TODO: Preserve ordering information?
    source_sql, index_ids, index_names = input_data._to_sql_query(
        include_index=True
    )

    tvf_sql = apply_sql_tvf(source_sql)
    result = self._session.read_gbq(tvf_sql, index_col=index_ids)
    result.index.names = index_names

    # Restore column labels: pair output columns with input labels by position.
    label_map = dict(zip(result.columns.values, input_data.columns.values))
    # NOTE(review): the value returned by rename() is not used. If rename()
    # returns a new frame instead of mutating in place (as pandas does by
    # default), the labels are never actually restored. Assigning the result
    # would change the returned labels and is only correct if the TVF output
    # really lists the input columns first -- confirm before changing.
    result.rename(columns=label_map)
    return result
11491

11592
def _keys(self):
    """Return the ``(session, model)`` pair held by this instance."""
    return self._session, self._model
@@ -137,13 +114,13 @@ def model(self) -> bigquery.Model:
137114
return self._model
138115

139116
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Apply the SQL generator's ``ml_predict`` to ``input_data``.

    Delegates to ``_apply_ml_tvf`` and returns its dataframe result.
    """
    tvf = self._model_manipulation_sql_generator.ml_predict
    return self._apply_ml_tvf(input_data, tvf)
144121

145122
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Apply the SQL generator's ``ml_transform`` to ``input_data``.

    Delegates to ``_apply_ml_tvf`` and returns its dataframe result.
    """
    tvf = self._model_manipulation_sql_generator.ml_transform
    return self._apply_ml_tvf(input_data, tvf)
@@ -153,10 +130,10 @@ def generate_text(
153130
input_data: bpd.DataFrame,
154131
options: Mapping[str, int | float],
155132
) -> bpd.DataFrame:
156-
return self._apply_sql(
133+
return self._apply_ml_tvf(
157134
input_data,
158-
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text(
159-
source_df=source_df,
135+
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
136+
source_sql=source_sql,
160137
struct_options=options,
161138
),
162139
)
@@ -166,10 +143,10 @@ def generate_embedding(
166143
input_data: bpd.DataFrame,
167144
options: Mapping[str, int | float],
168145
) -> bpd.DataFrame:
169-
return self._apply_sql(
146+
return self._apply_ml_tvf(
170147
input_data,
171-
lambda source_df: self._model_manipulation_sql_generator.ml_generate_embedding(
172-
source_df=source_df,
148+
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_embedding(
149+
source_sql=source_sql,
173150
struct_options=options,
174151
),
175152
)
@@ -179,10 +156,10 @@ def detect_anomalies(
179156
) -> bpd.DataFrame:
180157
assert self._model.model_type in ("PCA", "KMEANS", "ARIMA_PLUS")
181158

182-
return self._apply_sql(
159+
return self._apply_ml_tvf(
183160
input_data,
184-
lambda source_df: self._model_manipulation_sql_generator.ml_detect_anomalies(
185-
source_df=source_df,
161+
lambda source_sql: self._model_manipulation_sql_generator.ml_detect_anomalies(
162+
source_sql=source_sql,
186163
struct_options=options,
187164
),
188165
)
@@ -192,7 +169,9 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
192169
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()
193170

194171
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
    """Build the ``ml_evaluate`` SQL and read its result through the session.

    Args:
        input_data:
            Optional dataframe; when given, its ``.sql`` is passed to the
            SQL generator, otherwise ``None`` is passed.

    Returns:
        Whatever ``self._session.read_gbq`` returns for the generated SQL.
    """
    source_sql = None if input_data is None else input_data.sql
    sql = self._model_manipulation_sql_generator.ml_evaluate(source_sql)
    return self._session.read_gbq(sql)
198177

@@ -202,7 +181,7 @@ def llm_evaluate(
202181
task_type: Optional[str] = None,
203182
):
204183
sql = self._model_manipulation_sql_generator.ml_llm_evaluate(
205-
input_data, task_type
184+
input_data.sql, task_type
206185
)
207186

208187
return self._session.read_gbq(sql)
@@ -336,7 +315,7 @@ def create_model(
336315
model_ref = self._create_model_ref(session._anonymous_dataset)
337316

338317
sql = self._model_creation_sql_generator.create_model(
339-
source_df=input_data,
318+
source_sql=input_data.sql,
340319
model_ref=model_ref,
341320
transforms=transforms,
342321
options=options,
@@ -374,7 +353,7 @@ def create_llm_remote_model(
374353
model_ref = self._create_model_ref(session._anonymous_dataset)
375354

376355
sql = self._model_creation_sql_generator.create_llm_remote_model(
377-
source_df=input_data,
356+
source_sql=input_data.sql,
378357
model_ref=model_ref,
379358
options=options,
380359
connection_name=connection_name,
@@ -407,7 +386,7 @@ def create_time_series_model(
407386
model_ref = self._create_model_ref(session._anonymous_dataset)
408387

409388
sql = self._model_creation_sql_generator.create_model(
410-
source_df=input_data,
389+
source_sql=input_data.sql,
411390
model_ref=model_ref,
412391
transforms=transforms,
413392
options=options,

bigframes/ml/metrics/pairwise.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@
1717

1818
import bigframes_vendored.sklearn.metrics.pairwise as vendored_metrics_pairwise
1919

20-
from bigframes.ml import core, utils
20+
from bigframes.ml import utils
21+
import bigframes.operations as ops
2122
import bigframes.pandas as bpd
2223

2324

2425
def paired_cosine_distances(
    X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
) -> bpd.DataFrame:
    # No docstring here: __doc__ is assigned right after this definition.
    X, Y = utils.convert_to_series(X, Y)
    # Outer-join the two inputs so both vectors sit in one block.
    joined_block, _ = X._block.join(Y._block, how="outer")

    left_id = joined_block.value_columns[0]
    right_id = joined_block.value_columns[1]
    distance_expr = ops.cosine_distance_op.as_expr(left_id, right_id)

    # Project the distance expression as a new column labelled "cosine_distance".
    result_block, _ = joined_block.project_expr(
        distance_expr,
        label="cosine_distance",
    )
    return bpd.DataFrame(result_block)
3338

3439

3540
paired_cosine_distances.__doc__ = inspect.getdoc(
@@ -40,12 +45,16 @@ def paired_cosine_distances(
4045
def paired_manhattan_distance(
    X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
) -> bpd.DataFrame:
    # No docstring here: __doc__ is assigned right after this definition.
    X, Y = utils.convert_to_series(X, Y)
    # Outer-join the two inputs so both vectors sit in one block.
    joined_block, _ = X._block.join(Y._block, how="outer")

    left_id = joined_block.value_columns[0]
    right_id = joined_block.value_columns[1]
    distance_expr = ops.manhattan_distance_op.as_expr(left_id, right_id)

    # Project the distance expression as a new column labelled "manhattan_distance".
    result_block, _ = joined_block.project_expr(
        distance_expr,
        label="manhattan_distance",
    )
    return bpd.DataFrame(result_block)
4958

5059

5160
paired_manhattan_distance.__doc__ = inspect.getdoc(
@@ -56,12 +65,16 @@ def paired_manhattan_distance(
5665
def paired_euclidean_distances(
    X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
) -> bpd.DataFrame:
    # No docstring here: __doc__ is assigned right after this definition.
    X, Y = utils.convert_to_series(X, Y)
    # Outer-join the two inputs so both vectors sit in one block.
    joined_block, _ = X._block.join(Y._block, how="outer")

    left_id = joined_block.value_columns[0]
    right_id = joined_block.value_columns[1]
    distance_expr = ops.euclidean_distance_op.as_expr(left_id, right_id)

    # Project the distance expression as a new column labelled "euclidean_distance".
    result_block, _ = joined_block.project_expr(
        distance_expr,
        label="euclidean_distance",
    )
    return bpd.DataFrame(result_block)
6578

6679

6780
paired_euclidean_distances.__doc__ = inspect.getdoc(

0 commit comments

Comments
 (0)