Skip to content

Commit 8d82945

Browse files
authored
feat: add ml PCA.detect_anomalies method (#422)
* feat: add ml detect_anomalies * add PCA.detect_anomalies * fix mypy
1 parent ae0e3ea commit 8d82945

File tree

8 files changed

+124
-16
lines changed

8 files changed

+124
-16
lines changed

bigframes/ml/core.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,12 @@ def model(self) -> bigquery.Model:
128128
return self._model
129129

130130
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
131-
# TODO: validate input data schema
132131
return self._apply_sql(
133132
input_data,
134133
self._model_manipulation_sql_generator.ml_predict,
135134
)
136135

137136
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
138-
# TODO: validate input data schema
139137
return self._apply_sql(
140138
input_data,
141139
self._model_manipulation_sql_generator.ml_transform,
@@ -146,7 +144,6 @@ def generate_text(
146144
input_data: bpd.DataFrame,
147145
options: Mapping[str, int | float],
148146
) -> bpd.DataFrame:
149-
# TODO: validate input data schema
150147
return self._apply_sql(
151148
input_data,
152149
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text(
@@ -160,7 +157,6 @@ def generate_text_embedding(
160157
input_data: bpd.DataFrame,
161158
options: Mapping[str, int | float],
162159
) -> bpd.DataFrame:
163-
# TODO: validate input data schema
164160
return self._apply_sql(
165161
input_data,
166162
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text_embedding(
@@ -169,12 +165,24 @@ def generate_text_embedding(
169165
),
170166
)
171167

168+
def detect_anomalies(
    self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
    """Run ML.DETECT_ANOMALIES for this model over ``input_data``.

    Args:
        input_data: DataFrame handed to ``self._apply_sql`` as the query source.
        options: Struct options forwarded unchanged to the SQL generator's
            ``ml_detect_anomalies`` (for example ``{"contamination": 0.1}``).

    Returns:
        Whatever ``self._apply_sql`` returns for the generated statement.

    Raises:
        ValueError: If the underlying model type is not one of
            ``PCA``, ``KMEANS`` or ``ARIMA_PLUS``.
    """
    supported_model_types = ("PCA", "KMEANS", "ARIMA_PLUS")
    model_type = self._model.model_type
    # Raise explicitly rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable this guard.
    if model_type not in supported_model_types:
        raise ValueError(
            f"detect_anomalies supports only {', '.join(supported_model_types)} "
            f"models, but the model type is {model_type}."
        )

    return self._apply_sql(
        input_data,
        lambda source_df: self._model_manipulation_sql_generator.ml_detect_anomalies(
            source_df=source_df,
            struct_options=options,
        ),
    )
180+
172181
def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Generate the ML.FORECAST statement, read it, and return the result.

    The query result is read with ``forecast_timestamp`` as the index and
    the index is then reset before the DataFrame is returned.
    """
    forecast_sql = self._model_manipulation_sql_generator.ml_forecast(
        struct_options=options
    )
    indexed_result = self._session.read_gbq(
        forecast_sql, index_col="forecast_timestamp"
    )
    return indexed_result.reset_index()
175184

176185
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
177-
# TODO: validate input data schema
178186
sql = self._model_manipulation_sql_generator.ml_evaluate(input_data)
179187

180188
return self._session.read_gbq(sql)

bigframes/ml/decomposition.py

+28
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,34 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
110110

111111
return self._bqml_model.predict(X)
112112

113+
def detect_anomalies(
    self, X: Union[bpd.DataFrame, bpd.Series], *, contamination: float = 0.1
) -> bpd.DataFrame:
    """Detect the anomaly data points of the input.

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series):
            Series or a DataFrame to detect anomalies.
        contamination (float, default 0.1):
            Identifies the proportion of anomalies in the training dataset that are used to create the model.
            The value must be in the range [0, 0.5].

    Returns:
        bigframes.dataframe.DataFrame: detected DataFrame.

    Raises:
        ValueError: If ``contamination`` is outside [0, 0.5] (NaN included).
        RuntimeError: If the model has not been fitted yet.
    """
    # A chained comparison also rejects NaN: every comparison against NaN is
    # False, so the former ``< 0.0 or > 0.5`` test let it through.
    if not 0.0 <= contamination <= 0.5:
        raise ValueError(
            f"contamination must be [0.0, 0.5], but is {contamination}."
        )

    if not self._bqml_model:
        raise RuntimeError("A model must be fitted before detect_anomalies")

    (X,) = utils.convert_to_dataframe(X)

    return self._bqml_model.detect_anomalies(
        X, options={"contamination": contamination}
    )
140+
113141
def to_gbq(self, model_name: str, replace: bool = False) -> PCA:
114142
"""Save the model to BigQuery.
115143

bigframes/ml/imported.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import bigframes
2424
from bigframes.core import log_adapter
2525
from bigframes.ml import base, core, globals, utils
26-
from bigframes.ml.globals import _SUPPORTED_DTYPES
2726
import bigframes.pandas as bpd
2827

2928

@@ -236,9 +235,9 @@ def _create_bqml_model(self):
236235
else:
237236
for io in (self.input, self.output):
238237
for v in io.values():
239-
if v not in _SUPPORTED_DTYPES:
238+
if v not in globals._SUPPORTED_DTYPES:
240239
raise ValueError(
241-
f"field_type {v} is not supported. We only support {', '.join(_SUPPORTED_DTYPES)}."
240+
f"field_type {v} is not supported. We only support {', '.join(globals._SUPPORTED_DTYPES)}."
242241
)
243242

244243
return self._bqml_model_factory.create_xgboost_imported_model(

bigframes/ml/remote.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from bigframes import clients
2424
from bigframes.core import log_adapter
2525
from bigframes.ml import base, core, globals, utils
26-
from bigframes.ml.globals import _SUPPORTED_DTYPES
2726
import bigframes.pandas as bpd
2827

2928
_REMOTE_MODEL_STATUS = "remote_model_status"
@@ -102,9 +101,9 @@ def standardize_type(v: str):
102101
v = v.lower()
103102
v = v.replace("boolean", "bool")
104103

105-
if v not in _SUPPORTED_DTYPES:
104+
if v not in globals._SUPPORTED_DTYPES:
106105
raise ValueError(
107-
f"Data type {v} is not supported. We only support {', '.join(_SUPPORTED_DTYPES)}."
106+
f"Data type {v} is not supported. We only support {', '.join(globals._SUPPORTED_DTYPES)}."
108107
)
109108

110109
return v

bigframes/ml/sql.py

+8
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ def ml_generate_text_embedding(
276276
return f"""SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `{self._model_name}`,
277277
({self._source_sql(source_df)}), {struct_options_sql})"""
278278

279+
def ml_detect_anomalies(
280+
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
281+
) -> str:
282+
"""Encode ML.DETECT_ANOMALIES for BQML.

Args:
source_df: DataFrame passed to ``self._source_sql`` to build the
parenthesised input query.
struct_options: Options expanded as keyword arguments into
``self.struct_options``.

Returns:
The SQL statement as a string.
"""
283+
struct_options_sql = self.struct_options(**struct_options)
# Argument order in the generated call: model, struct options, then the source query.
284+
return f"""SELECT * FROM ML.DETECT_ANOMALIES(MODEL `{self._model_name}`,
285+
{struct_options_sql}, ({self._source_sql(source_df)}))"""
286+
279287
# ML evaluation TVFs
280288
def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
281289
"""Encode ML.EVALUATE for BQML"""

tests/system/small/ml/test_core.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,29 @@ def test_model_predict_with_unnamed_index(
289289
)
290290

291291

292+
def test_model_detect_anomalies(
    penguins_bqml_pca_model: core.BqmlModel, new_penguins_df
):
    """With contamination 0.25, all three new penguin rows are flagged as anomalies."""
    compared_columns = ["is_anomaly", "mean_squared_error"]
    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    want = pd.DataFrame(
        {
            "is_anomaly": [True, True, True],
            "mean_squared_error": [0.254188, 0.731243, 0.298889],
        },
        index=tag_numbers,
    )

    detected = penguins_bqml_pca_model.detect_anomalies(
        new_penguins_df, {"contamination": 0.25}
    )
    got = detected.to_pandas()[compared_columns].sort_index()

    # Loose tolerance: only the two compared columns are checked, not exact floats.
    pd.testing.assert_frame_equal(
        got,
        want,
        check_exact=False,
        check_dtype=False,
        rtol=0.1,
    )
313+
314+
292315
def test_remote_model_predict(
293316
bqml_linear_remote_model: core.BqmlModel, new_penguins_df
294317
):
@@ -367,16 +390,19 @@ def test_model_forecast(time_series_bqml_arima_plus_model: core.BqmlModel):
367390
)
368391

369392

370-
def test_model_register(ephemera_penguins_bqml_linear_model):
393+
def test_model_register(ephemera_penguins_bqml_linear_model: core.BqmlModel):
371394
model = ephemera_penguins_bqml_linear_model
372395
model.register()
373396

397+
assert model.model.model_id is not None
374398
model_name = "bigframes_" + model.model.model_id
375399
# Only registered model contains the field, and the field includes project/dataset. Here only check model_id.
376400
assert model_name in model.model.training_runs[-1]["vertexAiModelId"]
377401

378402

379-
def test_model_register_with_params(ephemera_penguins_bqml_linear_model):
403+
def test_model_register_with_params(
404+
ephemera_penguins_bqml_linear_model: core.BqmlModel,
405+
):
380406
model_name = "bigframes_system_test_model"
381407
model = ephemera_penguins_bqml_linear_model
382408
model.register(model_name)

tests/system/small/ml/test_decomposition.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
import pandas as pd
1616

1717
from bigframes.ml import decomposition
18+
import bigframes.pandas as bpd
1819
import tests.system.utils
1920

2021

21-
def test_pca_predict(penguins_pca_model, new_penguins_df):
22+
def test_pca_predict(
23+
penguins_pca_model: decomposition.PCA, new_penguins_df: bpd.DataFrame
24+
):
2225
predictions = penguins_pca_model.predict(new_penguins_df).to_pandas()
2326
expected = pd.DataFrame(
2427
{
@@ -35,6 +38,27 @@ def test_pca_predict(penguins_pca_model, new_penguins_df):
3538
)
3639

3740

41+
def test_pca_detect_anomalies(
    penguins_pca_model: decomposition.PCA, new_penguins_df: bpd.DataFrame
):
    """With the default contamination, only the middle new penguin row is flagged."""
    compared_columns = ["is_anomaly", "mean_squared_error"]
    tag_numbers = pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64")
    want = pd.DataFrame(
        {
            "is_anomaly": [False, True, False],
            "mean_squared_error": [0.254188, 0.731243, 0.298889],
        },
        index=tag_numbers,
    )

    detected = penguins_pca_model.detect_anomalies(new_penguins_df)
    got = detected.to_pandas()[compared_columns].sort_index()

    # Loose tolerance: only the two compared columns are checked, not exact floats.
    pd.testing.assert_frame_equal(
        got,
        want,
        check_exact=False,
        check_dtype=False,
        rtol=0.1,
    )
60+
61+
3862
def test_pca_score(penguins_pca_model: decomposition.PCA):
3963
result = penguins_pca_model.score().to_pandas()
4064
expected = pd.DataFrame(

tests/unit/ml/test_sql.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,8 @@ def test_ml_centroids_correct(
341341
)
342342

343343

344-
def test_forecast_correct_sql(
344+
def test_ml_forecast_correct_sql(
345345
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
346-
mock_df: bpd.DataFrame,
347346
):
348347
sql = model_manipulation_sql_generator.ml_forecast(
349348
struct_options={"option_key1": 1, "option_key2": 2.2},
@@ -391,6 +390,23 @@ def test_ml_generate_text_embedding_correct(
391390
)
392391

393392

393+
def test_ml_detect_anomalies_correct_sql(
394+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
395+
mock_df: bpd.DataFrame,
396+
):
# Checks the exact SQL text generated for a mock DataFrame and two struct options.
397+
sql = model_manipulation_sql_generator.ml_detect_anomalies(
398+
source_df=mock_df,
399+
struct_options={"option_key1": 1, "option_key2": 2.2},
400+
)
# The expected statement lists the STRUCT options before the parenthesised source SQL.
401+
assert (
402+
sql
403+
== """SELECT * FROM ML.DETECT_ANOMALIES(MODEL `my_project_id.my_dataset_id.my_model_id`,
404+
STRUCT(
405+
1 AS option_key1,
406+
2.2 AS option_key2), (input_X_sql))"""
407+
)
408+
409+
394410
def test_ml_principal_components_correct(
395411
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
396412
):

0 commit comments

Comments
 (0)