Skip to content

Commit eaa1db0

Browse files
authored
refactor: read transformer output columns from model entity (#817)
* refactor: read transformer output columns from model entity * fix tests and docs * remove dup code * fix comment
1 parent cdfd979 commit eaa1db0

File tree

5 files changed

+124
-170
lines changed

5 files changed

+124
-170
lines changed

bigframes/ml/base.py

+18
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,24 @@ def __init__(self):
188188
def _keys(self):
189189
pass
190190

191+
def _extract_output_names(self):
192+
"""Extract transform output column names. Save the results to self._output_names."""
193+
assert self._bqml_model is not None
194+
195+
output_names = []
196+
for transform_col in self._bqml_model._model._properties["transformColumns"]:
197+
transform_col_dict = cast(dict, transform_col)
198+
# pass the columns that are not transformed
199+
if "transformSql" not in transform_col_dict:
200+
continue
201+
transform_sql: str = transform_col_dict["transformSql"]
202+
if not transform_sql.startswith("ML."):
203+
continue
204+
205+
output_names.append(transform_col_dict["name"])
206+
207+
self._output_names = output_names
208+
191209
def __eq__(self, other) -> bool:
192210
return type(self) is type(other) and self._keys() == other._keys()
193211

bigframes/ml/compose.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -201,25 +201,20 @@ def _merge(
201201

202202
def _compile_to_sql(
203203
self,
204-
columns: List[str],
205204
X: bpd.DataFrame,
206-
) -> List[Tuple[str, str]]:
205+
) -> List[str]:
207206
"""Compile this transformer to a list of SQL expressions that can be included in
208207
a BQML TRANSFORM clause
209208
210209
Args:
211-
columns (List[str]):
212-
a list of column names to transform
213-
X (bpd.DataFrame):
214-
The Dataframe with training data.
210+
X: DataFrame to transform.
215211
216-
Returns:
217-
a list of tuples of (sql_expression, output_name)"""
212+
Returns: a list of sql_expr."""
218213
result = []
219214
for _, transformer, target_columns in self.transformers:
220215
if isinstance(target_columns, str):
221216
target_columns = [target_columns]
222-
result += transformer._compile_to_sql(target_columns, X=X)
217+
result += transformer._compile_to_sql(X, target_columns)
223218
return result
224219

225220
def fit(
@@ -229,17 +224,14 @@ def fit(
229224
) -> ColumnTransformer:
230225
(X,) = utils.convert_to_dataframe(X)
231226

232-
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
233-
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
234-
227+
transform_sqls = self._compile_to_sql(X)
235228
self._bqml_model = self._bqml_model_factory.create_model(
236229
X,
237230
options={"model_type": "transform_only"},
238231
transforms=transform_sqls,
239232
)
240233

241-
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
242-
self._output_names = [name for _, name in compiled_transforms]
234+
self._extract_output_names()
243235
return self
244236

245237
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:

bigframes/ml/impute.py

+13-19
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import typing
21-
from typing import Iterable, List, Literal, Optional, Tuple, Union
21+
from typing import Iterable, List, Literal, Optional, Union
2222

2323
import bigframes_vendored.sklearn.impute._base
2424

@@ -49,25 +49,22 @@ def _keys(self):
4949

5050
def _compile_to_sql(
5151
self,
52-
columns: Iterable[str],
53-
X=None,
54-
) -> List[Tuple[str, str]]:
52+
X: bpd.DataFrame,
53+
columns: Optional[Iterable[str]] = None,
54+
) -> List[str]:
5555
"""Compile this transformer to a list of SQL expressions that can be included in
5656
a BQML TRANSFORM clause
5757
5858
Args:
59-
columns:
60-
A list of column names to transform.
61-
X:
62-
The Dataframe with training data.
59+
X: DataFrame to transform.
60+
columns: transform columns. If None, transform all columns in X.
6361
64-
Returns: a list of tuples of (sql_expression, output_name)"""
62+
Returns: a list of tuples sql_expr."""
63+
if columns is None:
64+
columns = X.columns
6565
return [
66-
(
67-
self._base_sql_generator.ml_imputer(
68-
column, self.strategy, f"imputer_{column}"
69-
),
70-
f"imputer_{column}",
66+
self._base_sql_generator.ml_imputer(
67+
column, self.strategy, f"imputer_{column}"
7168
)
7269
for column in columns
7370
]
@@ -92,17 +89,14 @@ def fit(
9289
) -> SimpleImputer:
9390
(X,) = utils.convert_to_dataframe(X)
9491

95-
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
96-
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
97-
92+
transform_sqls = self._compile_to_sql(X)
9893
self._bqml_model = self._bqml_model_factory.create_model(
9994
X,
10095
options={"model_type": "transform_only"},
10196
transforms=transform_sqls,
10297
)
10398

104-
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
105-
self._output_names = [name for _, name in compiled_transforms]
99+
self._extract_output_names()
106100
return self
107101

108102
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:

bigframes/ml/pipeline.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ def fit(
106106
) -> Pipeline:
107107
(X,) = utils.convert_to_dataframe(X)
108108

109-
compiled_transforms = self._transform._compile_to_sql(X.columns.tolist(), X=X)
110-
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
111-
109+
transform_sqls = self._transform._compile_to_sql(X)
112110
if y is not None:
113111
# If labels columns are present, they should pass through un-transformed
114112
(y,) = utils.convert_to_dataframe(y)

0 commit comments

Comments
 (0)