Skip to content

Commit fb5d83b

Browse files
authored
feat: add ml PCA model params (#474)
1 parent 4727563 commit fb5d83b

File tree

3 files changed

+86
-4
lines changed

3 files changed

+86
-4
lines changed

bigframes/ml/decomposition.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import List, Optional, Union
20+
from typing import List, Literal, Optional, Union
2121

2222
import bigframes_vendored.sklearn.decomposition._pca
2323
from google.cloud import bigquery
@@ -35,21 +35,29 @@ class PCA(
3535
):
3636
__doc__ = bigframes_vendored.sklearn.decomposition._pca.PCA.__doc__
3737

38-
def __init__(self, n_components: int = 3):
38+
def __init__(
39+
self,
40+
n_components: int = 3,
41+
*,
42+
svd_solver: Literal["full", "randomized", "auto"] = "auto",
43+
):
3944
self.n_components = n_components
45+
self.svd_solver = svd_solver
4046
self._bqml_model: Optional[core.BqmlModel] = None
4147
self._bqml_model_factory = globals.bqml_model_factory()
4248

4349
@classmethod
4450
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> PCA:
4551
assert model.model_type == "PCA"
4652

47-
kwargs = {}
53+
kwargs: dict = {}
4854

4955
# See https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun
5056
last_fitting = model.training_runs[-1]["trainingOptions"]
5157
if "numPrincipalComponents" in last_fitting:
5258
kwargs["n_components"] = int(last_fitting["numPrincipalComponents"])
59+
if "pcaSolver" in last_fitting:
60+
kwargs["svd_solver"] = str(last_fitting["pcaSolver"])
5361

5462
new_pca = cls(**kwargs)
5563
new_pca._bqml_model = core.BqmlModel(session, model)
@@ -69,6 +77,7 @@ def _fit(
6977
options={
7078
"model_type": "PCA",
7179
"num_principal_components": self.n_components,
80+
"pca_solver": self.svd_solver,
7281
},
7382
)
7483
return self

tests/system/large/ml/test_decomposition.py

+71
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,74 @@ def test_decomposition_configure_fit_score_predict(
8484
in reloaded_model._bqml_model.model_name
8585
)
8686
assert reloaded_model.n_components == 3
87+
88+
89+
def test_decomposition_configure_fit_score_predict_params(
90+
session, penguins_df_default_index, dataset_id
91+
):
92+
model = decomposition.PCA(n_components=5, svd_solver="randomized")
93+
model.fit(penguins_df_default_index)
94+
95+
new_penguins = session.read_pandas(
96+
pd.DataFrame(
97+
{
98+
"tag_number": [1633, 1672, 1690],
99+
"species": [
100+
"Adelie Penguin (Pygoscelis adeliae)",
101+
"Gentoo penguin (Pygoscelis papua)",
102+
"Adelie Penguin (Pygoscelis adeliae)",
103+
],
104+
"island": ["Dream", "Biscoe", "Torgersen"],
105+
"culmen_length_mm": [37.8, 46.5, 41.1],
106+
"culmen_depth_mm": [18.1, 14.8, 18.6],
107+
"flipper_length_mm": [193.0, 217.0, 189.0],
108+
"body_mass_g": [3750.0, 5200.0, 3325.0],
109+
"sex": ["MALE", "FEMALE", "MALE"],
110+
}
111+
).set_index("tag_number")
112+
)
113+
114+
# Check score to ensure the model was fitted
115+
score_result = model.score(new_penguins).to_pandas()
116+
score_expected = pd.DataFrame(
117+
{
118+
"total_explained_variance_ratio": [0.932897],
119+
},
120+
dtype="Float64",
121+
)
122+
score_expected = score_expected.reindex(index=score_expected.index.astype("Int64"))
123+
124+
pd.testing.assert_frame_equal(
125+
score_result, score_expected, check_exact=False, rtol=0.1
126+
)
127+
128+
result = model.predict(new_penguins).to_pandas()
129+
expected = pd.DataFrame(
130+
{
131+
"principal_component_1": [-1.459, 2.258, -1.685],
132+
"principal_component_2": [-1.120, -1.351, -0.874],
133+
"principal_component_3": [-0.646, 0.443, -0.704],
134+
"principal_component_4": [-0.539, 0.234, -0.571],
135+
"principal_component_5": [-0.876, 0.122, 0.609],
136+
},
137+
dtype="Float64",
138+
index=pd.Index([1633, 1672, 1690], name="tag_number", dtype="Int64"),
139+
)
140+
141+
tests.system.utils.assert_pandas_df_equal_pca(
142+
result,
143+
expected,
144+
check_exact=False,
145+
rtol=0.1,
146+
)
147+
148+
# save, load, check n_components and svd_solver to ensure configuration was kept
149+
reloaded_model = model.to_gbq(
150+
f"{dataset_id}.temp_configured_pca_model", replace=True
151+
)
152+
assert (
153+
f"{dataset_id}.temp_configured_pca_model"
154+
in reloaded_model._bqml_model.model_name
155+
)
156+
assert reloaded_model.n_components == 5
157+
assert reloaded_model.svd_solver == "RANDOMIZED"

third_party/bigframes_vendored/sklearn/decomposition/_pca.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ class PCA(BaseEstimator, metaclass=ABCMeta):
3232
truncated SVD.
3333
3434
Args:
35-
n_components (Optional[int], default 3):
35+
n_components (Optional[int], default 3):
3636
Number of components to keep. If n_components is not set, all components
3737
are kept.
38+
svd_solver ("full", "randomized" or "auto", default "auto"):
39+
The solver to use to calculate the principal components. Details: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-pca#pca_solver.
3840
3941
"""
4042

0 commit comments

Comments
 (0)