Skip to content

Commit 24c6256

Browse files
authored
feat: add ml.preprocessing.KBinsDiscretizer (#81)
* feat: add ml.preprocessing.KBinsDiscretizer * fix: address all the comments * fix: address additional comments * fix: fix the failed test * Empty commit * Trigger Kokoro
1 parent fff3d45 commit 24c6256

File tree

9 files changed

+426
-17
lines changed

9 files changed

+426
-17
lines changed

bigframes/ml/compose.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
preprocessing.StandardScaler,
3232
preprocessing.MaxAbsScaler,
3333
preprocessing.MinMaxScaler,
34+
preprocessing.KBinsDiscretizer,
3435
preprocessing.LabelEncoder,
3536
]
3637

@@ -91,18 +92,24 @@ def transformers_(
9192

9293
return result
9394

94-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
95+
def _compile_to_sql(
96+
self,
97+
columns: List[str],
98+
X: bpd.DataFrame,
99+
) -> List[Tuple[str, str]]:
95100
"""Compile this transformer to a list of SQL expressions that can be included in
96101
a BQML TRANSFORM clause
97102
98103
Args:
99104
columns (List[str]):
100105
a list of column names to transform
106+
X (bpd.DataFrame):
107+
The Dataframe with training data.
101108
102109
Returns:
103110
a list of tuples of (sql_expression, output_name)"""
104111
return [
105-
transformer._compile_to_sql([column])[0]
112+
transformer._compile_to_sql([column], X=X)[0]
106113
for column in columns
107114
for _, transformer, target_column in self.transformers_
108115
if column == target_column
@@ -115,7 +122,7 @@ def fit(
115122
) -> ColumnTransformer:
116123
(X,) = utils.convert_to_dataframe(X)
117124

118-
compiled_transforms = self._compile_to_sql(X.columns.tolist())
125+
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
119126
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
120127

121128
self._bqml_model = self._bqml_model_factory.create_model(

bigframes/ml/pipeline.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
5252
preprocessing.OneHotEncoder,
5353
preprocessing.MaxAbsScaler,
5454
preprocessing.MinMaxScaler,
55+
preprocessing.KBinsDiscretizer,
5556
preprocessing.LabelEncoder,
5657
),
5758
):
@@ -93,7 +94,7 @@ def fit(
9394
) -> Pipeline:
9495
(X,) = utils.convert_to_dataframe(X)
9596

96-
compiled_transforms = self._transform._compile_to_sql(X.columns.tolist())
97+
compiled_transforms = self._transform._compile_to_sql(X.columns.tolist(), X=X)
9798
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
9899

99100
if y is not None:
@@ -151,6 +152,7 @@ def _extract_as_column_transformer(
151152
preprocessing.StandardScaler,
152153
preprocessing.MaxAbsScaler,
153154
preprocessing.MinMaxScaler,
155+
preprocessing.KBinsDiscretizer,
154156
preprocessing.LabelEncoder,
155157
],
156158
Union[str, List[str]],
@@ -190,6 +192,13 @@ def _extract_as_column_transformer(
190192
*preprocessing.MinMaxScaler._parse_from_sql(transform_sql),
191193
)
192194
)
195+
elif transform_sql.startswith("ML.BUCKETIZE"):
196+
transformers.append(
197+
(
198+
"k_bins_discretizer",
199+
*preprocessing.KBinsDiscretizer._parse_from_sql(transform_sql),
200+
)
201+
)
193202
elif transform_sql.startswith("ML.LABEL_ENCODER"):
194203
transformers.append(
195204
(
@@ -213,6 +222,7 @@ def _merge_column_transformer(
213222
preprocessing.OneHotEncoder,
214223
preprocessing.MaxAbsScaler,
215224
preprocessing.MinMaxScaler,
225+
preprocessing.KBinsDiscretizer,
216226
preprocessing.LabelEncoder,
217227
]:
218228
"""Try to merge the column transformer to a simple transformer."""

bigframes/ml/preprocessing.py

+142-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.ml import base, core, globals, utils
2424
import bigframes.pandas as bpd
2525
import third_party.bigframes_vendored.sklearn.preprocessing._data
26+
import third_party.bigframes_vendored.sklearn.preprocessing._discretization
2627
import third_party.bigframes_vendored.sklearn.preprocessing._encoder
2728
import third_party.bigframes_vendored.sklearn.preprocessing._label
2829

@@ -44,12 +45,15 @@ def __init__(self):
4445
def __eq__(self, other: Any) -> bool:
4546
return type(other) is StandardScaler and self._bqml_model == other._bqml_model
4647

47-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
48+
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
4849
"""Compile this transformer to a list of SQL expressions that can be included in
4950
a BQML TRANSFORM clause
5051
5152
Args:
52-
columns: a list of column names to transform
53+
columns:
54+
a list of column names to transform.
55+
X (default None):
56+
Ignored.
5357
5458
Returns: a list of tuples of (sql_expression, output_name)"""
5559
return [
@@ -124,12 +128,15 @@ def __init__(self):
124128
def __eq__(self, other: Any) -> bool:
125129
return type(other) is MaxAbsScaler and self._bqml_model == other._bqml_model
126130

127-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
131+
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
128132
"""Compile this transformer to a list of SQL expressions that can be included in
129133
a BQML TRANSFORM clause
130134
131135
Args:
132-
columns: a list of column names to transform
136+
columns:
137+
a list of column names to transform.
138+
X (default None):
139+
Ignored.
133140
134141
Returns: a list of tuples of (sql_expression, output_name)"""
135142
return [
@@ -204,12 +211,15 @@ def __init__(self):
204211
def __eq__(self, other: Any) -> bool:
205212
return type(other) is MinMaxScaler and self._bqml_model == other._bqml_model
206213

207-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
214+
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
208215
"""Compile this transformer to a list of SQL expressions that can be included in
209216
a BQML TRANSFORM clause
210217
211218
Args:
212-
columns: a list of column names to transform
219+
columns:
220+
a list of column names to transform.
221+
X (default None):
222+
Ignored.
213223
214224
Returns: a list of tuples of (sql_expression, output_name)"""
215225
return [
@@ -267,6 +277,124 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
267277
)
268278

269279

280+
class KBinsDiscretizer(
281+
base.Transformer,
282+
third_party.bigframes_vendored.sklearn.preprocessing._discretization.KBinsDiscretizer,
283+
):
284+
__doc__ = (
285+
third_party.bigframes_vendored.sklearn.preprocessing._discretization.KBinsDiscretizer.__doc__
286+
)
287+
288+
def __init__(
289+
self,
290+
n_bins: int = 5,
291+
strategy: Literal["uniform", "quantile"] = "quantile",
292+
):
293+
if strategy != "uniform":
294+
raise NotImplementedError(
295+
f"Only strategy = 'uniform' is supported now, input is {strategy}."
296+
)
297+
if n_bins < 2:
298+
raise ValueError(
299+
f"n_bins has to be larger than or equal to 2, input is {n_bins}."
300+
)
301+
self.n_bins = n_bins
302+
self.strategy = strategy
303+
self._bqml_model: Optional[core.BqmlModel] = None
304+
self._bqml_model_factory = globals.bqml_model_factory()
305+
self._base_sql_generator = globals.base_sql_generator()
306+
307+
# TODO(garrettwu): implement __hash__
308+
def __eq__(self, other: Any) -> bool:
309+
return (
310+
type(other) is KBinsDiscretizer
311+
and self.n_bins == other.n_bins
312+
and self._bqml_model == other._bqml_model
313+
)
314+
315+
def _compile_to_sql(
316+
self,
317+
columns: List[str],
318+
X: bpd.DataFrame,
319+
) -> List[Tuple[str, str]]:
320+
"""Compile this transformer to a list of SQL expressions that can be included in
321+
a BQML TRANSFORM clause
322+
323+
Args:
324+
columns:
325+
a list of column names to transform
326+
X:
327+
The Dataframe with training data.
328+
329+
Returns: a list of tuples of (sql_expression, output_name)"""
330+
array_split_points = {}
331+
if self.strategy == "uniform":
332+
for column in columns:
333+
min_value = X[column].min()
334+
max_value = X[column].max()
335+
bin_size = (max_value - min_value) / self.n_bins
336+
array_split_points[column] = [
337+
min_value + i * bin_size for i in range(self.n_bins - 1)
338+
]
339+
340+
return [
341+
(
342+
self._base_sql_generator.ml_bucketize(
343+
column, array_split_points[column], f"kbinsdiscretizer_{column}"
344+
),
345+
f"kbinsdiscretizer_{column}",
346+
)
347+
for column in columns
348+
]
349+
350+
@classmethod
351+
def _parse_from_sql(cls, sql: str) -> tuple[KBinsDiscretizer, str]:
352+
"""Parse SQL to tuple(KBinsDiscretizer, column_label).
353+
354+
Args:
355+
sql: SQL string of format "ML.BUCKETIZE({col_label}, array_split_points, FALSE) OVER()"
356+
357+
Returns:
358+
tuple(KBinsDiscretizer, column_label)"""
359+
s = sql[sql.find("(") + 1 : sql.find(")")]
360+
array_split_points = s[s.find("[") + 1 : s.find("]")]
361+
col_label = s[: s.find(",")]
362+
n_bins = array_split_points.count(",") + 2
363+
return cls(n_bins, "uniform"), col_label
364+
365+
def fit(
366+
self,
367+
X: Union[bpd.DataFrame, bpd.Series],
368+
y=None, # ignored
369+
) -> KBinsDiscretizer:
370+
(X,) = utils.convert_to_dataframe(X)
371+
372+
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
373+
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
374+
375+
self._bqml_model = self._bqml_model_factory.create_model(
376+
X,
377+
options={"model_type": "transform_only"},
378+
transforms=transform_sqls,
379+
)
380+
381+
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
382+
self._output_names = [name for _, name in compiled_transforms]
383+
return self
384+
385+
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
386+
if not self._bqml_model:
387+
raise RuntimeError("Must be fitted before transform")
388+
389+
(X,) = utils.convert_to_dataframe(X)
390+
391+
df = self._bqml_model.transform(X)
392+
return typing.cast(
393+
bpd.DataFrame,
394+
df[self._output_names],
395+
)
396+
397+
270398
class OneHotEncoder(
271399
base.Transformer,
272400
third_party.bigframes_vendored.sklearn.preprocessing._encoder.OneHotEncoder,
@@ -308,13 +436,15 @@ def __eq__(self, other: Any) -> bool:
308436
and self.max_categories == other.max_categories
309437
)
310438

311-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
439+
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
312440
"""Compile this transformer to a list of SQL expressions that can be included in
313441
a BQML TRANSFORM clause
314442
315443
Args:
316444
columns:
317-
a list of column names to transform
445+
a list of column names to transform.
446+
X (default None):
447+
Ignored.
318448
319449
Returns: a list of tuples of (sql_expression, output_name)"""
320450

@@ -432,13 +562,15 @@ def __eq__(self, other: Any) -> bool:
432562
and self.max_categories == other.max_categories
433563
)
434564

435-
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
565+
def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
436566
"""Compile this transformer to a list of SQL expressions that can be included in
437567
a BQML TRANSFORM clause
438568
439569
Args:
440570
columns:
441-
a list of column names to transform
571+
a list of column names to transform.
572+
X (default None):
573+
Ignored.
442574
443575
Returns: a list of tuples of (sql_expression, output_name)"""
444576

bigframes/ml/sql.py

+9
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ def ml_min_max_scaler(self, numeric_expr_sql: str, name: str) -> str:
8585
"""Encode ML.MIN_MAX_SCALER for BQML"""
8686
return f"""ML.MIN_MAX_SCALER({numeric_expr_sql}) OVER() AS {name}"""
8787

88+
def ml_bucketize(
89+
self,
90+
numeric_expr_sql: str,
91+
array_split_points: Iterable[Union[int, float]],
92+
name: str,
93+
) -> str:
94+
"""Encode ML.MIN_MAX_SCALER for BQML"""
95+
return f"""ML.BUCKETIZE({numeric_expr_sql}, {array_split_points}, FALSE) AS {name}"""
96+
8897
def ml_one_hot_encoder(
8998
self,
9099
numeric_expr_sql: str,

0 commit comments

Comments
 (0)