Skip to content

Commit d805241

Browse files
authored
feat: add transformers save/load (#552)
* feat: add transformers save/load * fix mypy
1 parent f207c8f commit d805241

File tree

9 files changed

+173
-95
lines changed

9 files changed

+173
-95
lines changed

bigframes/ml/base.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,33 @@ def fit(
178178
return self._fit(X, y)
179179

180180

181-
class Transformer(BaseEstimator):
181+
class BaseTransformer(BaseEstimator):
182+
"""Transformer base class."""
183+
184+
def __init__(self):
185+
self._bqml_model: Optional[core.BqmlModel] = None
186+
187+
_T = TypeVar("_T", bound="BaseTransformer")
188+
189+
def to_gbq(self: _T, model_name: str, replace: bool = False) -> _T:
190+
"""Save the transformer as a BigQuery model.
191+
192+
Args:
193+
model_name (str):
194+
the name of the model.
195+
replace (bool, default False):
196+
whether to replace the model if it already exists. Defaults to False.
197+
198+
Returns:
199+
Saved transformer."""
200+
if not self._bqml_model:
201+
raise RuntimeError("A transformer must be fitted before it can be saved")
202+
203+
new_model = self._bqml_model.copy(model_name, replace)
204+
return new_model.session.read_gbq_model(model_name)
205+
206+
207+
class Transformer(BaseTransformer):
182208
"""A BigQuery DataFrames Transformer base class that transforms data.
183209
184210
Also the transformers can be attached to a pipeline with a predictor."""
@@ -199,7 +225,7 @@ def fit_transform(
199225
return self.fit(X, y).transform(X)
200226

201227

202-
class LabelTransformer(BaseEstimator):
228+
class LabelTransformer(BaseTransformer):
203229
"""A BigQuery DataFrames Label Transformer base class that transforms data.
204230
205231
Also the transformers can be attached to a pipeline with a predictor."""

bigframes/ml/compose.py

+5-49
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,11 @@
2626
import bigframes_vendored.sklearn.compose._column_transformer
2727
from google.cloud import bigquery
2828

29-
import bigframes
3029
from bigframes import constants
3130
from bigframes.core import log_adapter
3231
from bigframes.ml import base, core, globals, preprocessing, utils
3332
import bigframes.pandas as bpd
3433

35-
_PREPROCESSING_TYPES = Union[
36-
preprocessing.OneHotEncoder,
37-
preprocessing.StandardScaler,
38-
preprocessing.MaxAbsScaler,
39-
preprocessing.MinMaxScaler,
40-
preprocessing.KBinsDiscretizer,
41-
preprocessing.LabelEncoder,
42-
]
43-
4434
_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
4535
{
4636
"ML.STANDARD_SCALER": preprocessing.StandardScaler,
@@ -67,7 +57,7 @@ def __init__(
6757
transformers: List[
6858
Tuple[
6959
str,
70-
_PREPROCESSING_TYPES,
60+
preprocessing.PreprocessingType,
7161
Union[str, List[str]],
7262
]
7363
],
@@ -82,12 +72,12 @@ def __init__(
8272
@property
8373
def transformers_(
8474
self,
85-
) -> List[Tuple[str, _PREPROCESSING_TYPES, str,]]:
75+
) -> List[Tuple[str, preprocessing.PreprocessingType, str,]]:
8676
"""The collection of transformers as tuples of (name, transformer, column)."""
8777
result: List[
8878
Tuple[
8979
str,
90-
_PREPROCESSING_TYPES,
80+
preprocessing.PreprocessingType,
9181
str,
9282
]
9383
] = []
@@ -105,15 +95,6 @@ def transformers_(
10595

10696
return result
10797

108-
@classmethod
109-
def _from_bq(
110-
cls, session: bigframes.Session, model: bigquery.Model
111-
) -> ColumnTransformer:
112-
col_transformer = cls._extract_from_bq_model(model)
113-
col_transformer._bqml_model = core.BqmlModel(session, model)
114-
115-
return col_transformer
116-
11798
@classmethod
11899
def _extract_from_bq_model(
119100
cls,
@@ -125,7 +106,7 @@ def _extract_from_bq_model(
125106
transformers: List[
126107
Tuple[
127108
str,
128-
_PREPROCESSING_TYPES,
109+
preprocessing.PreprocessingType,
129110
Union[str, List[str]],
130111
]
131112
] = []
@@ -164,15 +145,7 @@ def camel_to_snake(name):
164145

165146
def _merge(
166147
self, bq_model: bigquery.Model
167-
) -> Union[
168-
ColumnTransformer,
169-
preprocessing.StandardScaler,
170-
preprocessing.OneHotEncoder,
171-
preprocessing.MaxAbsScaler,
172-
preprocessing.MinMaxScaler,
173-
preprocessing.KBinsDiscretizer,
174-
preprocessing.LabelEncoder,
175-
]:
148+
) -> Union[ColumnTransformer, preprocessing.PreprocessingType,]:
176149
"""Try to merge the column transformer into a simple transformer. This depends on all the columns in bq_model being transformed with the same transformer."""
177150
transformers = self.transformers_
178151

@@ -249,20 +222,3 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
249222
bpd.DataFrame,
250223
df[self._output_names],
251224
)
252-
253-
def to_gbq(self, model_name: str, replace: bool = False) -> ColumnTransformer:
254-
"""Save the transformer as a BigQuery model.
255-
256-
Args:
257-
model_name (str):
258-
the name of the model.
259-
replace (bool, default False):
260-
whether to replace if the model already exists. Default to False.
261-
262-
Returns:
263-
ColumnTransformer: saved model."""
264-
if not self._bqml_model:
265-
raise RuntimeError("A transformer must be fitted before it can be saved")
266-
267-
new_model = self._bqml_model.copy(model_name, replace)
268-
return new_model.session.read_gbq_model(model_name)

bigframes/ml/loader.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
from bigframes.ml import (
2525
cluster,
2626
compose,
27+
core,
2728
decomposition,
2829
ensemble,
2930
forecasting,
3031
imported,
3132
linear_model,
3233
llm,
3334
pipeline,
35+
preprocessing,
3436
utils,
3537
)
3638

@@ -81,6 +83,7 @@ def from_bq(
8183
llm.PaLM2TextEmbeddingGenerator,
8284
pipeline.Pipeline,
8385
compose.ColumnTransformer,
86+
preprocessing.PreprocessingType,
8487
]:
8588
"""Load a BQML model to BigQuery DataFrames ML.
8689
@@ -107,8 +110,12 @@ def from_bq(
107110

108111

109112
def _transformer_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
110-
# TODO(garrettwu): add other transformers
111-
return compose.ColumnTransformer._from_bq(session, bq_model)
113+
transformer = compose.ColumnTransformer._extract_from_bq_model(bq_model)._merge(
114+
bq_model
115+
)
116+
transformer._bqml_model = core.BqmlModel(session, bq_model)
117+
118+
return transformer
112119

113120

114121
def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):

bigframes/ml/preprocessing.py

+10
Original file line numberDiff line numberDiff line change
@@ -639,3 +639,13 @@ def transform(self, y: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
639639
bpd.DataFrame,
640640
df[self._output_names],
641641
)
642+
643+
644+
PreprocessingType = Union[
645+
OneHotEncoder,
646+
StandardScaler,
647+
MaxAbsScaler,
648+
MinMaxScaler,
649+
KBinsDiscretizer,
650+
LabelEncoder,
651+
]

tests/system/large/ml/test_compose.py

+1
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,4 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id):
151151
("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
152152
]
153153
assert reloaded_transformer.transformers_ == expected
154+
assert reloaded_transformer._bqml_model is not None

tests/system/large/ml/test_pipeline.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def test_pipeline_logistic_regression_fit_score_predict(
222222
)
223223

224224

225-
@pytest.mark.flaky(retries=2, delay=120)
225+
@pytest.mark.flaky(retries=2)
226226
def test_pipeline_xgbregressor_fit_score_predict(session, penguins_df_default_index):
227227
"""Test a supervised model with a minimal preprocessing step"""
228228
pl = pipeline.Pipeline(
@@ -297,7 +297,7 @@ def test_pipeline_xgbregressor_fit_score_predict(session, penguins_df_default_in
297297
)
298298

299299

300-
@pytest.mark.flaky(retries=2, delay=120)
300+
@pytest.mark.flaky(retries=2)
301301
def test_pipeline_random_forest_classifier_fit_score_predict(
302302
session, penguins_df_default_index
303303
):
@@ -445,7 +445,7 @@ def test_pipeline_PCA_fit_score_predict(session, penguins_df_default_index):
445445
)
446446

447447

448-
@pytest.mark.flaky(retries=2, delay=120)
448+
@pytest.mark.flaky(retries=2)
449449
def test_pipeline_standard_scaler_kmeans_fit_score_predict(
450450
session, penguins_pandas_df_default_index
451451
):

tests/system/small/ml/test_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def test_remote_model_predict(
333333
)
334334

335335

336-
@pytest.mark.flaky(retries=2, delay=120)
336+
@pytest.mark.flaky(retries=2)
337337
def test_model_generate_text(
338338
bqml_palm2_text_generator_model: core.BqmlModel, llm_text_df
339339
):

tests/system/small/ml/test_llm.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_create_text_generator_32k_model(
4949
assert reloaded_model.connection_name == bq_connection
5050

5151

52-
@pytest.mark.flaky(retries=2, delay=120)
52+
@pytest.mark.flaky(retries=2)
5353
def test_create_text_generator_model_default_session(
5454
bq_connection, llm_text_pandas_df, bigquery_client
5555
):
@@ -76,7 +76,7 @@ def test_create_text_generator_model_default_session(
7676
assert all(series.str.len() > 20)
7777

7878

79-
@pytest.mark.flaky(retries=2, delay=120)
79+
@pytest.mark.flaky(retries=2)
8080
def test_create_text_generator_32k_model_default_session(
8181
bq_connection, llm_text_pandas_df, bigquery_client
8282
):
@@ -103,7 +103,7 @@ def test_create_text_generator_32k_model_default_session(
103103
assert all(series.str.len() > 20)
104104

105105

106-
@pytest.mark.flaky(retries=2, delay=120)
106+
@pytest.mark.flaky(retries=2)
107107
def test_create_text_generator_model_default_connection(
108108
llm_text_pandas_df, bigquery_client
109109
):
@@ -131,7 +131,7 @@ def test_create_text_generator_model_default_connection(
131131

132132

133133
# Marked as flaky only because BQML LLM is in preview; the service has limited capacity and is not stable enough.
134-
@pytest.mark.flaky(retries=2, delay=120)
134+
@pytest.mark.flaky(retries=2)
135135
def test_text_generator_predict_default_params_success(
136136
palm2_text_generator_model, llm_text_df
137137
):
@@ -142,7 +142,7 @@ def test_text_generator_predict_default_params_success(
142142
assert all(series.str.len() > 20)
143143

144144

145-
@pytest.mark.flaky(retries=2, delay=120)
145+
@pytest.mark.flaky(retries=2)
146146
def test_text_generator_predict_series_default_params_success(
147147
palm2_text_generator_model, llm_text_df
148148
):
@@ -153,7 +153,7 @@ def test_text_generator_predict_series_default_params_success(
153153
assert all(series.str.len() > 20)
154154

155155

156-
@pytest.mark.flaky(retries=2, delay=120)
156+
@pytest.mark.flaky(retries=2)
157157
def test_text_generator_predict_arbitrary_col_label_success(
158158
palm2_text_generator_model, llm_text_df
159159
):
@@ -165,7 +165,7 @@ def test_text_generator_predict_arbitrary_col_label_success(
165165
assert all(series.str.len() > 20)
166166

167167

168-
@pytest.mark.flaky(retries=2, delay=120)
168+
@pytest.mark.flaky(retries=2)
169169
def test_text_generator_predict_with_params_success(
170170
palm2_text_generator_model, llm_text_df
171171
):
@@ -255,7 +255,7 @@ def test_create_text_embedding_generator_multilingual_model_defaults(bq_connecti
255255
assert model._bqml_model is not None
256256

257257

258-
@pytest.mark.flaky(retries=2, delay=120)
258+
@pytest.mark.flaky(retries=2)
259259
def test_embedding_generator_predict_success(
260260
palm2_embedding_generator_model, llm_text_df
261261
):
@@ -267,7 +267,7 @@ def test_embedding_generator_predict_success(
267267
assert len(value) == 768
268268

269269

270-
@pytest.mark.flaky(retries=2, delay=120)
270+
@pytest.mark.flaky(retries=2)
271271
def test_embedding_generator_multilingual_predict_success(
272272
palm2_embedding_generator_multilingual_model, llm_text_df
273273
):
@@ -279,7 +279,7 @@ def test_embedding_generator_multilingual_predict_success(
279279
assert len(value) == 768
280280

281281

282-
@pytest.mark.flaky(retries=2, delay=120)
282+
@pytest.mark.flaky(retries=2)
283283
def test_embedding_generator_predict_series_success(
284284
palm2_embedding_generator_model, llm_text_df
285285
):
@@ -306,7 +306,7 @@ def test_create_gemini_text_generator_model(
306306
assert reloaded_model.connection_name == bq_connection
307307

308308

309-
@pytest.mark.flaky(retries=2, delay=120)
309+
@pytest.mark.flaky(retries=2)
310310
def test_gemini_text_generator_predict_default_params_success(
311311
gemini_text_generator_model, llm_text_df
312312
):
@@ -317,7 +317,7 @@ def test_gemini_text_generator_predict_default_params_success(
317317
assert all(series.str.len() > 20)
318318

319319

320-
@pytest.mark.flaky(retries=2, delay=120)
320+
@pytest.mark.flaky(retries=2)
321321
def test_gemini_text_generator_predict_with_params_success(
322322
gemini_text_generator_model, llm_text_df
323323
):

0 commit comments

Comments
 (0)