Skip to content

Commit 4c4415f

Browse files
authored
feat: support ml.SimpleImputer in bigframes (#708)
* feat: support ml.Imputer in bigframes * address comments * address more comments * address more comments
1 parent de0881b commit 4c4415f

File tree

14 files changed

+388
-7
lines changed

14 files changed

+388
-7
lines changed

bigframes/ml/compose.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from bigframes import constants
3030
from bigframes.core import log_adapter
31-
from bigframes.ml import base, core, globals, preprocessing, utils
31+
from bigframes.ml import base, core, globals, impute, preprocessing, utils
3232
import bigframes.pandas as bpd
3333

3434
_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
@@ -40,6 +40,7 @@
4040
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
4141
"ML.QUANTILE_BUCKETIZE": preprocessing.KBinsDiscretizer,
4242
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
43+
"ML.IMPUTER": impute.SimpleImputer,
4344
}
4445
)
4546

@@ -58,7 +59,7 @@ def __init__(
5859
transformers: List[
5960
Tuple[
6061
str,
61-
preprocessing.PreprocessingType,
62+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
6263
Union[str, List[str]],
6364
]
6465
],
@@ -73,12 +74,14 @@ def __init__(
7374
@property
7475
def transformers_(
7576
self,
76-
) -> List[Tuple[str, preprocessing.PreprocessingType, str,]]:
77+
) -> List[
78+
Tuple[str, Union[preprocessing.PreprocessingType, impute.SimpleImputer], str]
79+
]:
7780
"""The collection of transformers as tuples of (name, transformer, column)."""
7881
result: List[
7982
Tuple[
8083
str,
81-
preprocessing.PreprocessingType,
84+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
8285
str,
8386
]
8487
] = []
@@ -107,7 +110,7 @@ def _extract_from_bq_model(
107110
transformers: List[
108111
Tuple[
109112
str,
110-
preprocessing.PreprocessingType,
113+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
111114
Union[str, List[str]],
112115
]
113116
] = []
@@ -152,7 +155,9 @@ def camel_to_snake(name):
152155

153156
def _merge(
154157
self, bq_model: bigquery.Model
155-
) -> Union[ColumnTransformer, preprocessing.PreprocessingType,]:
158+
) -> Union[
159+
ColumnTransformer, Union[preprocessing.PreprocessingType, impute.SimpleImputer]
160+
]:
156161
"""Try to merge the column transformer into a single simple transformer. This is only possible when all the columns in bq_model are transformed with the same transformer."""
157162
transformers = self.transformers_
158163

bigframes/ml/impute.py

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Transformers for missing value imputation. This module is styled after
16+
scikit-learn's impute module: https://scikit-learn.org/stable/modules/impute.html."""
17+
18+
from __future__ import annotations
19+
20+
import typing
21+
from typing import Any, List, Literal, Optional, Tuple, Union
22+
23+
import bigframes_vendored.sklearn.impute._base
24+
25+
from bigframes.core import log_adapter
26+
from bigframes.ml import base, core, globals, utils
27+
import bigframes.pandas as bpd
28+
29+
30+
@log_adapter.class_logger
31+
class SimpleImputer(
32+
base.Transformer,
33+
bigframes_vendored.sklearn.impute._base.SimpleImputer,
34+
):
35+
36+
__doc__ = bigframes_vendored.sklearn.impute._base.SimpleImputer.__doc__
37+
38+
def __init__(
39+
self,
40+
strategy: Literal["mean", "median", "most_frequent"] = "mean",
41+
):
42+
self.strategy = strategy
43+
self._bqml_model: Optional[core.BqmlModel] = None
44+
self._bqml_model_factory = globals.bqml_model_factory()
45+
self._base_sql_generator = globals.base_sql_generator()
46+
47+
# TODO(garrettwu): implement __hash__
48+
def __eq__(self, other: Any) -> bool:
49+
return (
50+
type(other) is SimpleImputer
51+
and self.strategy == other.strategy
52+
and self._bqml_model == other._bqml_model
53+
)
54+
55+
def _compile_to_sql(
56+
self,
57+
columns: List[str],
58+
X=None,
59+
) -> List[Tuple[str, str]]:
60+
"""Compile this transformer to a list of SQL expressions that can be included in
61+
a BQML TRANSFORM clause
62+
63+
Args:
64+
columns:
65+
A list of column names to transform.
66+
X:
67+
The Dataframe with training data.
68+
69+
Returns: a list of tuples of (sql_expression, output_name)"""
70+
return [
71+
(
72+
self._base_sql_generator.ml_imputer(
73+
column, self.strategy, f"imputer_{column}"
74+
),
75+
f"imputer_{column}",
76+
)
77+
for column in columns
78+
]
79+
80+
@classmethod
81+
def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
82+
"""Parse SQL to tuple(SimpleImputer, column_label).
83+
84+
Args:
85+
sql: SQL string of format "ML.IMPUTER({col_label}, {strategy}) OVER()"
86+
87+
Returns:
88+
tuple(SimpleImputer, column_label)"""
89+
s = sql[sql.find("(") + 1 : sql.find(")")]
90+
col_label, strategy = s.split(", ")
91+
return cls(strategy[1:-1]), col_label # type: ignore[arg-type]
92+
93+
def fit(
94+
self,
95+
X: Union[bpd.DataFrame, bpd.Series],
96+
y=None, # ignored
97+
) -> SimpleImputer:
98+
(X,) = utils.convert_to_dataframe(X)
99+
100+
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
101+
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
102+
103+
self._bqml_model = self._bqml_model_factory.create_model(
104+
X,
105+
options={"model_type": "transform_only"},
106+
transforms=transform_sqls,
107+
)
108+
109+
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
110+
self._output_names = [name for _, name in compiled_transforms]
111+
return self
112+
113+
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
114+
if not self._bqml_model:
115+
raise RuntimeError("Must be fitted before transform")
116+
117+
(X,) = utils.convert_to_dataframe(X)
118+
119+
df = self._bqml_model.transform(X)
120+
return typing.cast(
121+
bpd.DataFrame,
122+
df[self._output_names],
123+
)

bigframes/ml/loader.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ensemble,
3030
forecasting,
3131
imported,
32+
impute,
3233
linear_model,
3334
llm,
3435
pipeline,
@@ -84,6 +85,7 @@ def from_bq(
8485
pipeline.Pipeline,
8586
compose.ColumnTransformer,
8687
preprocessing.PreprocessingType,
88+
impute.SimpleImputer,
8789
]:
8890
"""Load a BQML model to BigQuery DataFrames ML.
8991

bigframes/ml/pipeline.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@
2626
import bigframes
2727
import bigframes.constants as constants
2828
from bigframes.core import log_adapter
29-
from bigframes.ml import base, compose, forecasting, loader, preprocessing, utils
29+
from bigframes.ml import (
30+
base,
31+
compose,
32+
forecasting,
33+
impute,
34+
loader,
35+
preprocessing,
36+
utils,
37+
)
3038
import bigframes.pandas as bpd
3139

3240

@@ -56,6 +64,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
5664
preprocessing.MinMaxScaler,
5765
preprocessing.KBinsDiscretizer,
5866
preprocessing.LabelEncoder,
67+
impute.SimpleImputer,
5968
),
6069
):
6170
self._transform = transform

bigframes/ml/preprocessing.py

+1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def __eq__(self, other: Any) -> bool:
305305
return (
306306
type(other) is KBinsDiscretizer
307307
and self.n_bins == other.n_bins
308+
and self.strategy == other.strategy
308309
and self._bqml_model == other._bqml_model
309310
)
310311

bigframes/ml/sql.py

+9
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ def ml_min_max_scaler(self, numeric_expr_sql: str, name: str) -> str:
103103
"""Encode ML.MIN_MAX_SCALER for BQML"""
104104
return f"""ML.MIN_MAX_SCALER({numeric_expr_sql}) OVER() AS {name}"""
105105

106+
def ml_imputer(
107+
self,
108+
expr_sql: str,
109+
strategy: str,
110+
name: str,
111+
) -> str:
112+
"""Encode ML.IMPUTER for BQML"""
113+
return f"""ML.IMPUTER({expr_sql}, '{strategy}') OVER() AS {name}"""
114+
106115
def ml_bucketize(
107116
self,
108117
numeric_expr_sql: str,
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
bigframes.ml.impute
2+
==========================
3+
4+
.. automodule:: bigframes.ml.impute
5+
:members:
6+
:inherited-members:
7+
:undoc-members:

docs/reference/bigframes.ml/index.rst

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ API Reference
1919

2020
imported
2121

22+
impute
23+
2224
linear_model
2325

2426
llm

docs/templates/toc.yml

+6
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@
134134
- name: XGBoostModel
135135
uid: bigframes.ml.imported.XGBoostModel
136136
name: imported
137+
- items:
138+
- name: Overview
139+
uid: bigframes.ml.impute
140+
- name: SimpleImputer
141+
uid: bigframes.ml.impute.SimpleImputer
142+
name: impute
137143
- items:
138144
- name: Overview
139145
uid: bigframes.ml.linear_model

tests/system/conftest.py

+14
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
import google.cloud.resourcemanager_v3 as resourcemanager_v3
3030
import google.cloud.storage as storage # type: ignore
3131
import ibis.backends.base
32+
import numpy as np
3233
import pandas as pd
3334
import pytest
3435
import pytz
3536
import test_utils.prefixer
3637

3738
import bigframes
3839
import bigframes.dataframe
40+
import bigframes.pandas as bpd
3941
import tests.system.utils
4042

4143
# Use this to control the number of cloud functions being deleted in a single
@@ -624,6 +626,18 @@ def new_penguins_pandas_df():
624626
).set_index("tag_number")
625627

626628

629+
@pytest.fixture(scope="session")
630+
def missing_values_penguins_df():
631+
"""Additional data matching the missing values penguins dataset"""
632+
return bpd.DataFrame(
633+
{
634+
"culmen_length_mm": [39.5, 38.5, 37.9],
635+
"culmen_depth_mm": [np.nan, 17.2, 18.1],
636+
"flipper_length_mm": [np.nan, 181.0, 188.0],
637+
}
638+
)
639+
640+
627641
@pytest.fixture(scope="session")
628642
def new_penguins_df(session, new_penguins_pandas_df):
629643
return session.read_pandas(new_penguins_pandas_df)

0 commit comments

Comments
 (0)