feat: add ColumnTransformer save/load #541

Merged · 1 commit · Mar 29, 2024
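
For reviewers, a minimal usage sketch of the save/load flow this PR enables. The dataset path `my_project.my_dataset.penguins_transformer` and the transformer/column choices are illustrative only, and the global-session helper is assumed to behave as elsewhere in `bigframes.pandas`:

```python
import bigframes.pandas as bpd
from bigframes.ml.compose import ColumnTransformer
from bigframes.ml.preprocessing import OneHotEncoder, StandardScaler

df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")

transformer = ColumnTransformer(
    [
        ("one_hot_encoder", OneHotEncoder(), "species"),
        ("standard_scaler", StandardScaler(), ["culmen_length_mm", "flipper_length_mm"]),
    ]
)
transformer.fit(df)

# New in this PR: persist the fitted transformer as a TRANSFORM_ONLY BQML model...
transformer.to_gbq("my_project.my_dataset.penguins_transformer", replace=True)

# ...and load it back later; read_gbq_model now dispatches to ColumnTransformer.
loaded = bpd.get_global_session().read_gbq_model(
    "my_project.my_dataset.penguins_transformer"
)
result = loaded.transform(df)
```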
133 changes: 128 additions & 5 deletions bigframes/ml/compose.py
@@ -18,16 +18,21 @@

from __future__ import annotations

import re
import types
import typing
from typing import List, Optional, Tuple, Union
from typing import cast, List, Optional, Tuple, Union

import bigframes_vendored.sklearn.compose._column_transformer
from google.cloud import bigquery

import bigframes
from bigframes import constants
from bigframes.core import log_adapter
from bigframes.ml import base, core, globals, preprocessing, utils
import bigframes.pandas as bpd

CompilablePreprocessorType = Union[
_PREPROCESSING_TYPES = Union[
preprocessing.OneHotEncoder,
preprocessing.StandardScaler,
preprocessing.MaxAbsScaler,
@@ -36,6 +41,17 @@
preprocessing.LabelEncoder,
]

_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
{
"ML.STANDARD_SCALER": preprocessing.StandardScaler,
"ML.ONE_HOT_ENCODER": preprocessing.OneHotEncoder,
"ML.MAX_ABS_SCALER": preprocessing.MaxAbsScaler,
"ML.MIN_MAX_SCALER": preprocessing.MinMaxScaler,
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
}
)


@log_adapter.class_logger
class ColumnTransformer(
@@ -51,7 +67,7 @@ def __init__(
transformers: List[
Tuple[
str,
CompilablePreprocessorType,
_PREPROCESSING_TYPES,
Union[str, List[str]],
]
],
@@ -66,12 +82,12 @@
@property
def transformers_(
self,
) -> List[Tuple[str, CompilablePreprocessorType, str,]]:
) -> List[Tuple[str, _PREPROCESSING_TYPES, str,]]:
"""The collection of transformers as tuples of (name, transformer, column)."""
result: List[
Tuple[
str,
CompilablePreprocessorType,
_PREPROCESSING_TYPES,
str,
]
] = []
@@ -89,6 +105,96 @@ def transformers_(

return result

@classmethod
def _from_bq(
cls, session: bigframes.Session, model: bigquery.Model
) -> ColumnTransformer:
col_transformer = cls._extract_from_bq_model(model)
col_transformer._bqml_model = core.BqmlModel(session, model)

return col_transformer

@classmethod
def _extract_from_bq_model(
cls,
bq_model: bigquery.Model,
) -> ColumnTransformer:
"""Extract transformers as ColumnTransformer obj from a BQ Model. Keep the _bqml_model field as None."""
assert "transformColumns" in bq_model._properties

transformers: List[
Tuple[
str,
_PREPROCESSING_TYPES,
Union[str, List[str]],
]
] = []

def camel_to_snake(name):
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()

for transform_col in bq_model._properties["transformColumns"]:
# skip pass-through columns that are not transformed
if "transformSql" not in transform_col:
continue
transform_sql: str = cast(dict, transform_col)["transformSql"]
if not transform_sql.startswith("ML."):
continue

found_transformer = False
for prefix in _BQML_TRANSFROM_TYPE_MAPPING:
if transform_sql.startswith(prefix):
transformer_cls = _BQML_TRANSFROM_TYPE_MAPPING[prefix]
transformers.append(
(
camel_to_snake(transformer_cls.__name__),
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
)
)

found_transformer = True
break
if not found_transformer:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
)

return cls(transformers=transformers)

def _merge(
self, bq_model: bigquery.Model
) -> Union[
ColumnTransformer,
preprocessing.StandardScaler,
preprocessing.OneHotEncoder,
preprocessing.MaxAbsScaler,
preprocessing.MinMaxScaler,
preprocessing.KBinsDiscretizer,
preprocessing.LabelEncoder,
]:
"""Try to merge the column transformer to a simple transformer. Depends on all the columns in bq_model are transformed with the same transformer."""
transformers = self.transformers_

assert len(transformers) > 0
_, transformer_0, column_0 = transformers[0]
columns = [column_0]
for _, transformer, column in transformers[1:]:
# all transformers are the same
if transformer != transformer_0:
return self
columns.append(column)
# all feature columns are transformed
if sorted(
[
cast(str, feature_column.name)
for feature_column in bq_model.feature_columns
]
) == sorted(columns):
return transformer_0

return self

def _compile_to_sql(
self,
columns: List[str],
@@ -143,3 +249,20 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
bpd.DataFrame,
df[self._output_names],
)

def to_gbq(self, model_name: str, replace: bool = False) -> ColumnTransformer:
"""Save the transformer as a BigQuery model.

Args:
model_name (str):
The name of the model.
replace (bool, default False):
Whether to replace the model if it already exists. Defaults to False.

Returns:
ColumnTransformer: saved model."""
if not self._bqml_model:
raise RuntimeError("A transformer must be fitted before it can be saved")

new_model = self._bqml_model.copy(model_name, replace)
return new_model.session.read_gbq_model(model_name)
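
In the new `_extract_from_bq_model`, transformer names are derived with `camel_to_snake` from the preprocessor class name rather than the hard-coded strings the old pipeline.py code used (which also drops the old `ont_hot_encoder` typo). An illustrative, self-contained copy of the helper and the names it yields:

```python
# Illustrative copy of the camel_to_snake helper used in _extract_from_bq_model,
# showing the transformer names derived from the preprocessor class names.
import re

def camel_to_snake(name: str) -> str:
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()

print(camel_to_snake("OneHotEncoder"))     # one_hot_encoder
print(camel_to_snake("StandardScaler"))    # standard_scaler
print(camel_to_snake("KBinsDiscretizer"))  # k_bins_discretizer
```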
30 changes: 25 additions & 5 deletions bigframes/ml/loader.py
@@ -23,6 +23,7 @@
import bigframes.constants as constants
from bigframes.ml import (
cluster,
compose,
decomposition,
ensemble,
forecasting,
@@ -79,6 +80,7 @@ def from_bq(
llm.PaLM2TextGenerator,
llm.PaLM2TextEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
]:
"""Load a BQML model to BigQuery DataFrames ML.

@@ -89,22 +91,32 @@
Returns:
A BigQuery DataFrames ML model object.
"""
# TODO(garrettwu): reduce the entire condition to TRANSFORM_ONLY once b/331679273 is fixed.
if (
bq_model.model_type == "TRANSFORM_ONLY"
or bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
and "transformColumns" in bq_model._properties
and not _is_bq_model_remote(bq_model)
):
return _transformer_from_bq(session, bq_model)

if _is_bq_model_pipeline(bq_model):
return pipeline.Pipeline._from_bq(session, bq_model)

return _model_from_bq(session, bq_model)


def _transformer_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
# TODO(garrettwu): add other transformers
return compose.ColumnTransformer._from_bq(session, bq_model)


def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
if bq_model.model_type in _BQML_MODEL_TYPE_MAPPING:
return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore
session=session, model=bq_model
)
if (
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
and "remoteModelInfo" in bq_model._properties
and "endpoint" in bq_model._properties["remoteModelInfo"]
):
if _is_bq_model_remote(bq_model):
# Parse the remote model endpoint
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
model_endpoint = bqml_endpoint.split("/")[-1]
@@ -121,3 +133,11 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):

def _is_bq_model_pipeline(bq_model: bigquery.Model) -> bool:
return "transformColumns" in bq_model._properties


def _is_bq_model_remote(bq_model: bigquery.Model) -> bool:
return (
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
and "remoteModelInfo" in bq_model._properties
and "endpoint" in bq_model._properties["remoteModelInfo"]
)
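
One subtlety in the new `from_bq` routing: `and` binds tighter than `or`, so the `transformColumns` and remote-model checks only gate the `MODEL_TYPE_UNSPECIFIED` branch, while `TRANSFORM_ONLY` models always route to `_transformer_from_bq`. A rough standalone restatement for illustration (the helper name `_routes_to_transformer` is hypothetical, not part of the PR):

```python
# Hypothetical helper restating the dispatch predicate in loader.from_bq,
# with explicit parentheses to show the operator precedence.
def _routes_to_transformer(model_type: str, properties: dict) -> bool:
    is_remote = (
        model_type == "MODEL_TYPE_UNSPECIFIED"
        and "endpoint" in properties.get("remoteModelInfo", {})
    )
    return model_type == "TRANSFORM_ONLY" or (
        model_type == "MODEL_TYPE_UNSPECIFIED"
        and "transformColumns" in properties
        and not is_remote
    )

# TRANSFORM_ONLY models (the new ColumnTransformer.to_gbq output) always match.
assert _routes_to_transformer("TRANSFORM_ONLY", {})
# Models reporting MODEL_TYPE_UNSPECIFIED but carrying transform columns
# (the b/331679273 workaround) match as well.
assert _routes_to_transformer("MODEL_TYPE_UNSPECIFIED", {"transformColumns": []})
# Remote models fall through to the remote-model branch instead.
assert not _routes_to_transformer(
    "MODEL_TYPE_UNSPECIFIED",
    {"transformColumns": [], "remoteModelInfo": {"endpoint": "..."}},
)
```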
113 changes: 3 additions & 110 deletions bigframes/ml/pipeline.py
@@ -18,7 +18,7 @@

from __future__ import annotations

from typing import cast, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import bigframes_vendored.sklearn.pipeline
from google.cloud import bigquery
@@ -83,8 +83,8 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):

@classmethod
def _from_bq(cls, session: bigframes.Session, bq_model: bigquery.Model) -> Pipeline:
col_transformer = _extract_as_column_transformer(bq_model)
transform = _merge_column_transformer(bq_model, col_transformer)
col_transformer = compose.ColumnTransformer._extract_from_bq_model(bq_model)
transform = col_transformer._merge(bq_model)

estimator = loader._model_from_bq(session, bq_model)
return cls([("transform", transform), ("estimator", estimator)])
@@ -138,110 +138,3 @@ def to_gbq(self, model_name: str, replace: bool = False) -> Pipeline:
new_model = self._estimator._bqml_model.copy(model_name, replace)

return new_model.session.read_gbq_model(model_name)


def _extract_as_column_transformer(
bq_model: bigquery.Model,
) -> compose.ColumnTransformer:
"""Extract transformers as ColumnTransformer obj from a BQ Model."""
assert "transformColumns" in bq_model._properties

transformers: List[
Tuple[
str,
Union[
preprocessing.OneHotEncoder,
preprocessing.StandardScaler,
preprocessing.MaxAbsScaler,
preprocessing.MinMaxScaler,
preprocessing.KBinsDiscretizer,
preprocessing.LabelEncoder,
],
Union[str, List[str]],
]
] = []
for transform_col in bq_model._properties["transformColumns"]:
# pass the columns that are not transformed
if "transformSql" not in transform_col:
continue

transform_sql: str = cast(dict, transform_col)["transformSql"]
if transform_sql.startswith("ML.STANDARD_SCALER"):
transformers.append(
(
"standard_scaler",
*preprocessing.StandardScaler._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.ONE_HOT_ENCODER"):
transformers.append(
(
"ont_hot_encoder",
*preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.MAX_ABS_SCALER"):
transformers.append(
(
"max_abs_scaler",
*preprocessing.MaxAbsScaler._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.MIN_MAX_SCALER"):
transformers.append(
(
"min_max_scaler",
*preprocessing.MinMaxScaler._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.BUCKETIZE"):
transformers.append(
(
"k_bins_discretizer",
*preprocessing.KBinsDiscretizer._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.LABEL_ENCODER"):
transformers.append(
(
"label_encoder",
*preprocessing.LabelEncoder._parse_from_sql(transform_sql),
)
)
else:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
)

return compose.ColumnTransformer(transformers=transformers)


def _merge_column_transformer(
bq_model: bigquery.Model, column_transformer: compose.ColumnTransformer
) -> Union[
compose.ColumnTransformer,
preprocessing.StandardScaler,
preprocessing.OneHotEncoder,
preprocessing.MaxAbsScaler,
preprocessing.MinMaxScaler,
preprocessing.KBinsDiscretizer,
preprocessing.LabelEncoder,
]:
"""Try to merge the column transformer to a simple transformer."""
transformers = column_transformer.transformers_

assert len(transformers) > 0
_, transformer_0, column_0 = transformers[0]
columns = [column_0]
for _, transformer, column in transformers[1:]:
# all transformers are the same
if transformer != transformer_0:
return column_transformer
columns.append(column)
# all feature columns are transformed
if sorted(
[cast(str, feature_column.name) for feature_column in bq_model.feature_columns]
) == sorted(columns):
return transformer_0

return column_transformer
2 changes: 1 addition & 1 deletion bigframes/session/__init__.py
@@ -952,7 +952,7 @@ def read_gbq_model(self, model_name: str):
to load from the default project.

Returns:
A bigframes.ml Model wrapping the model.
A bigframes.ml Model, Transformer or Pipeline wrapping the model.
"""
import bigframes.ml.loader
