Skip to content

Commit 57d98b9

Browse files
authored
feat: add PolynomailFeatures to_gbq and pipeline support (#805)
* feat: add ml.preprocessing.PolynomialFeatures class * feat: add PolynomailFeatures to_gbq and pipeline support * fix tests * fix tests
1 parent ec5b068 commit 57d98b9

File tree

10 files changed

+182
-87
lines changed

10 files changed

+182
-87
lines changed

bigframes/ml/base.py

+10
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,16 @@ class BaseTransformer(BaseEstimator):
184184
def __init__(self):
185185
self._bqml_model: Optional[core.BqmlModel] = None
186186

187+
@abc.abstractmethod
188+
def _keys(self):
189+
pass
190+
191+
def __eq__(self, other) -> bool:
192+
return type(self) is type(other) and self._keys() == other._keys()
193+
194+
def __hash__(self) -> int:
195+
return hash(self._keys())
196+
187197
_T = TypeVar("_T", bound="BaseTransformer")
188198

189199
def to_gbq(self: _T, model_name: str, replace: bool = False) -> _T:

bigframes/ml/compose.py

+39-21
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import types
2323
import typing
24-
from typing import cast, List, Optional, Tuple, Union
24+
from typing import cast, Iterable, List, Optional, Set, Tuple, Union
2525

2626
import bigframes_vendored.sklearn.compose._column_transformer
2727
from google.cloud import bigquery
@@ -40,6 +40,7 @@
4040
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
4141
"ML.QUANTILE_BUCKETIZE": preprocessing.KBinsDiscretizer,
4242
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
43+
"ML.POLYNOMIAL_EXPAND": preprocessing.PolynomialFeatures,
4344
"ML.IMPUTER": impute.SimpleImputer,
4445
}
4546
)
@@ -56,21 +57,24 @@ class ColumnTransformer(
5657

5758
def __init__(
5859
self,
59-
transformers: List[
60+
transformers: Iterable[
6061
Tuple[
6162
str,
6263
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
63-
Union[str, List[str]],
64+
Union[str, Iterable[str]],
6465
]
6566
],
6667
):
6768
# TODO: if any(transformers) has fitted raise warning
68-
self.transformers = transformers
69+
self.transformers = list(transformers)
6970
self._bqml_model: Optional[core.BqmlModel] = None
7071
self._bqml_model_factory = globals.bqml_model_factory()
7172
# call self.transformers_ to check chained transformers
7273
self.transformers_
7374

75+
def _keys(self):
76+
return (self.transformers, self._bqml_model)
77+
7478
@property
7579
def transformers_(
7680
self,
@@ -107,13 +111,13 @@ def _extract_from_bq_model(
107111
"""Extract transformers as ColumnTransformer obj from a BQ Model. Keep the _bqml_model field as None."""
108112
assert "transformColumns" in bq_model._properties
109113

110-
transformers: List[
114+
transformers_set: Set[
111115
Tuple[
112116
str,
113117
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
114118
Union[str, List[str]],
115119
]
116-
] = []
120+
] = set()
117121

118122
def camel_to_snake(name):
119123
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
@@ -134,7 +138,7 @@ def camel_to_snake(name):
134138
for prefix in _BQML_TRANSFROM_TYPE_MAPPING:
135139
if transform_sql.startswith(prefix):
136140
transformer_cls = _BQML_TRANSFROM_TYPE_MAPPING[prefix]
137-
transformers.append(
141+
transformers_set.add(
138142
(
139143
camel_to_snake(transformer_cls.__name__),
140144
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
@@ -148,7 +152,7 @@ def camel_to_snake(name):
148152
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
149153
)
150154

151-
transformer = cls(transformers=transformers)
155+
transformer = cls(transformers=list(transformers_set))
152156
transformer._output_names = output_names
153157

154158
return transformer
@@ -159,23 +163,37 @@ def _merge(
159163
ColumnTransformer, Union[preprocessing.PreprocessingType, impute.SimpleImputer]
160164
]:
161165
"""Try to merge the column transformer to a simple transformer. Depends on all the columns in bq_model are transformed with the same transformer."""
162-
transformers = self.transformers_
166+
transformers = self.transformers
163167

164168
assert len(transformers) > 0
165169
_, transformer_0, column_0 = transformers[0]
170+
feature_columns_sorted = sorted(
171+
[
172+
cast(str, feature_column.name)
173+
for feature_column in bq_model.feature_columns
174+
]
175+
)
176+
177+
if (
178+
len(transformers) == 1
179+
and isinstance(transformer_0, preprocessing.PolynomialFeatures)
180+
and sorted(column_0) == feature_columns_sorted
181+
):
182+
transformer_0._output_names = self._output_names
183+
return transformer_0
184+
185+
if not isinstance(column_0, str):
186+
return self
166187
columns = [column_0]
167188
for _, transformer, column in transformers[1:]:
189+
if not isinstance(column, str):
190+
return self
168191
# all transformers are the same
169192
if transformer != transformer_0:
170193
return self
171194
columns.append(column)
172195
# all feature columns are transformed
173-
if sorted(
174-
[
175-
cast(str, feature_column.name)
176-
for feature_column in bq_model.feature_columns
177-
]
178-
) == sorted(columns):
196+
if sorted(columns) == feature_columns_sorted:
179197
transformer_0._output_names = self._output_names
180198
return transformer_0
181199

@@ -197,12 +215,12 @@ def _compile_to_sql(
197215
198216
Returns:
199217
a list of tuples of (sql_expression, output_name)"""
200-
return [
201-
transformer._compile_to_sql([column], X=X)[0]
202-
for column in columns
203-
for _, transformer, target_column in self.transformers_
204-
if column == target_column
205-
]
218+
result = []
219+
for _, transformer, target_columns in self.transformers:
220+
if isinstance(target_columns, str):
221+
target_columns = [target_columns]
222+
result += transformer._compile_to_sql(target_columns, X=X)
223+
return result
206224

207225
def fit(
208226
self,

bigframes/ml/core.py

+9
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
112112
self.model_name
113113
)
114114

115+
def _keys(self):
116+
return (self._session, self._model)
117+
118+
def __eq__(self, other):
119+
return isinstance(other, self.__class__) and self._keys() == other._keys()
120+
121+
def __hash__(self):
122+
return hash(self._keys())
123+
115124
@property
116125
def session(self) -> bigframes.Session:
117126
"""Get the BigQuery DataFrames session that this BQML model wrapper is tied to"""

bigframes/ml/impute.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import typing
21-
from typing import Any, List, Literal, Optional, Tuple, Union
21+
from typing import Iterable, List, Literal, Optional, Tuple, Union
2222

2323
import bigframes_vendored.sklearn.impute._base
2424

@@ -44,17 +44,12 @@ def __init__(
4444
self._bqml_model_factory = globals.bqml_model_factory()
4545
self._base_sql_generator = globals.base_sql_generator()
4646

47-
# TODO(garrettwu): implement __hash__
48-
def __eq__(self, other: Any) -> bool:
49-
return (
50-
type(other) is SimpleImputer
51-
and self.strategy == other.strategy
52-
and self._bqml_model == other._bqml_model
53-
)
47+
def _keys(self):
48+
return (self._bqml_model, self.strategy)
5449

5550
def _compile_to_sql(
5651
self,
57-
columns: List[str],
52+
columns: Iterable[str],
5853
X=None,
5954
) -> List[Tuple[str, str]]:
6055
"""Compile this transformer to a list of SQL expressions that can be included in

bigframes/ml/pipeline.py

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
6464
preprocessing.MinMaxScaler,
6565
preprocessing.KBinsDiscretizer,
6666
preprocessing.LabelEncoder,
67+
preprocessing.PolynomialFeatures,
6768
impute.SimpleImputer,
6869
),
6970
):

bigframes/ml/preprocessing.py

+30-53
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import typing
21-
from typing import Any, cast, List, Literal, Optional, Tuple, Union
21+
from typing import cast, Iterable, List, Literal, Optional, Tuple, Union
2222

2323
import bigframes_vendored.sklearn.preprocessing._data
2424
import bigframes_vendored.sklearn.preprocessing._discretization
@@ -43,11 +43,10 @@ def __init__(self):
4343
self._bqml_model_factory = globals.bqml_model_factory()
4444
self._base_sql_generator = globals.base_sql_generator()
4545

46-
# TODO(garrettwu): implement __hash__
47-
def __eq__(self, other: Any) -> bool:
48-
return type(other) is StandardScaler and self._bqml_model == other._bqml_model
46+
def _keys(self):
47+
return (self._bqml_model,)
4948

50-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
49+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
5150
"""Compile this transformer to a list of SQL expressions that can be included in
5251
a BQML TRANSFORM clause
5352
@@ -125,11 +124,10 @@ def __init__(self):
125124
self._bqml_model_factory = globals.bqml_model_factory()
126125
self._base_sql_generator = globals.base_sql_generator()
127126

128-
# TODO(garrettwu): implement __hash__
129-
def __eq__(self, other: Any) -> bool:
130-
return type(other) is MaxAbsScaler and self._bqml_model == other._bqml_model
127+
def _keys(self):
128+
return (self._bqml_model,)
131129

132-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
130+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
133131
"""Compile this transformer to a list of SQL expressions that can be included in
134132
a BQML TRANSFORM clause
135133
@@ -207,11 +205,10 @@ def __init__(self):
207205
self._bqml_model_factory = globals.bqml_model_factory()
208206
self._base_sql_generator = globals.base_sql_generator()
209207

210-
# TODO(garrettwu): implement __hash__
211-
def __eq__(self, other: Any) -> bool:
212-
return type(other) is MinMaxScaler and self._bqml_model == other._bqml_model
208+
def _keys(self):
209+
return (self._bqml_model,)
213210

214-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
211+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
215212
"""Compile this transformer to a list of SQL expressions that can be included in
216213
a BQML TRANSFORM clause
217214
@@ -301,18 +298,12 @@ def __init__(
301298
self._bqml_model_factory = globals.bqml_model_factory()
302299
self._base_sql_generator = globals.base_sql_generator()
303300

304-
# TODO(garrettwu): implement __hash__
305-
def __eq__(self, other: Any) -> bool:
306-
return (
307-
type(other) is KBinsDiscretizer
308-
and self.n_bins == other.n_bins
309-
and self.strategy == other.strategy
310-
and self._bqml_model == other._bqml_model
311-
)
301+
def _keys(self):
302+
return (self._bqml_model, self.n_bins, self.strategy)
312303

313304
def _compile_to_sql(
314305
self,
315-
columns: List[str],
306+
columns: Iterable[str],
316307
X: bpd.DataFrame,
317308
) -> List[Tuple[str, str]]:
318309
"""Compile this transformer to a list of SQL expressions that can be included in
@@ -446,17 +437,10 @@ def __init__(
446437
self._bqml_model_factory = globals.bqml_model_factory()
447438
self._base_sql_generator = globals.base_sql_generator()
448439

449-
# TODO(garrettwu): implement __hash__
450-
def __eq__(self, other: Any) -> bool:
451-
return (
452-
type(other) is OneHotEncoder
453-
and self._bqml_model == other._bqml_model
454-
and self.drop == other.drop
455-
and self.min_frequency == other.min_frequency
456-
and self.max_categories == other.max_categories
457-
)
440+
def _keys(self):
441+
return (self._bqml_model, self.drop, self.min_frequency, self.max_categories)
458442

459-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
443+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
460444
"""Compile this transformer to a list of SQL expressions that can be included in
461445
a BQML TRANSFORM clause
462446
@@ -572,16 +556,10 @@ def __init__(
572556
self._bqml_model_factory = globals.bqml_model_factory()
573557
self._base_sql_generator = globals.base_sql_generator()
574558

575-
# TODO(garrettwu): implement __hash__
576-
def __eq__(self, other: Any) -> bool:
577-
return (
578-
type(other) is LabelEncoder
579-
and self._bqml_model == other._bqml_model
580-
and self.min_frequency == other.min_frequency
581-
and self.max_categories == other.max_categories
582-
)
559+
def _keys(self):
560+
return (self._bqml_model, self.min_frequency, self.max_categories)
583561

584-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
562+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
585563
"""Compile this transformer to a list of SQL expressions that can be included in
586564
a BQML TRANSFORM clause
587565
@@ -672,18 +650,17 @@ class PolynomialFeatures(
672650
)
673651

674652
def __init__(self, degree: int = 2):
653+
if degree not in range(1, 5):
654+
raise ValueError(f"degree has to be [1, 4], input is {degree}.")
675655
self.degree = degree
676656
self._bqml_model: Optional[core.BqmlModel] = None
677657
self._bqml_model_factory = globals.bqml_model_factory()
678658
self._base_sql_generator = globals.base_sql_generator()
679659

680-
# TODO(garrettwu): implement __hash__
681-
def __eq__(self, other: Any) -> bool:
682-
return (
683-
type(other) is PolynomialFeatures and self._bqml_model == other._bqml_model
684-
)
660+
def _keys(self):
661+
return (self._bqml_model, self.degree)
685662

686-
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
663+
def _compile_to_sql(self, columns: Iterable[str], X=None) -> List[Tuple[str, str]]:
687664
"""Compile this transformer to a list of SQL expressions that can be included in
688665
a BQML TRANSFORM clause
689666
@@ -705,17 +682,18 @@ def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
705682
]
706683

707684
@classmethod
708-
def _parse_from_sql(cls, sql: str) -> tuple[PolynomialFeatures, str]:
709-
"""Parse SQL to tuple(PolynomialFeatures, column_label).
685+
def _parse_from_sql(cls, sql: str) -> tuple[PolynomialFeatures, tuple[str, ...]]:
686+
"""Parse SQL to tuple(PolynomialFeatures, column_labels).
710687
711688
Args:
712689
sql: SQL string of format "ML.POLYNOMIAL_EXPAND(STRUCT(col_label0, col_label1, ...), degree)"
713690
714691
Returns:
715692
tuple(MaxAbsScaler, column_label)"""
716-
col_label = sql[sql.find("STRUCT(") + 7 : sql.find(")")]
693+
col_labels = sql[sql.find("STRUCT(") + 7 : sql.find(")")].split(",")
694+
col_labels = [label.strip() for label in col_labels]
717695
degree = int(sql[sql.rfind(",") + 1 : sql.rfind(")")])
718-
return cls(degree), col_label
696+
return cls(degree), tuple(col_labels)
719697

720698
def fit(
721699
self,
@@ -762,8 +740,6 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
762740
df[self._output_names],
763741
)
764742

765-
# TODO(garrettwu): to_gbq()
766-
767743

768744
PreprocessingType = Union[
769745
OneHotEncoder,
@@ -772,4 +748,5 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
772748
MinMaxScaler,
773749
KBinsDiscretizer,
774750
LabelEncoder,
751+
PolynomialFeatures,
775752
]

0 commit comments

Comments
 (0)