Skip to content

Commit 65c6f47

Browse files
authored
feat!: rename ml model params (#491)
Includes following changes: * renaming min_rel_progress -> tol, to be consistent with sklearn * not allowing setting early_stop anymore, always to True * renaming n_parallell_trees -> n_estimators, to be consistent with sklearn * renaming class_weights -> class_weight, to be consistent with sklearn * renaming learn_rate -> learning_rate, to be consistent with sklearn * PCA n_components supports float value and None now, default to None
1 parent ae586e0 commit 65c6f47

File tree

13 files changed

+205
-195
lines changed

13 files changed

+205
-195
lines changed

bigframes/ml/cluster.py

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
"init_col": "kmeansInitializationColumn",
3434
"distance_type": "distanceType",
3535
"max_iter": "maxIterations",
36-
"early_stop": "earlyStop",
3736
"tol": "minRelativeProgress",
3837
}
3938

bigframes/ml/decomposition.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class PCA(
3737

3838
def __init__(
3939
self,
40-
n_components: int = 3,
40+
n_components: Optional[Union[int, float]] = None,
4141
*,
4242
svd_solver: Literal["full", "randomized", "auto"] = "auto",
4343
):
@@ -56,13 +56,31 @@ def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> PCA:
5656
last_fitting = model.training_runs[-1]["trainingOptions"]
5757
if "numPrincipalComponents" in last_fitting:
5858
kwargs["n_components"] = int(last_fitting["numPrincipalComponents"])
59+
if "pcaExplainedVarianceRatio" in last_fitting:
60+
kwargs["n_components"] = float(last_fitting["pcaExplainedVarianceRatio"])
5961
if "pcaSolver" in last_fitting:
6062
kwargs["svd_solver"] = str(last_fitting["pcaSolver"])
6163

6264
new_pca = cls(**kwargs)
6365
new_pca._bqml_model = core.BqmlModel(session, model)
6466
return new_pca
6567

68+
@property
69+
def _bqml_options(self) -> dict:
70+
"""The model options as they will be set for BQML"""
71+
options: dict = {
72+
"model_type": "PCA",
73+
"pca_solver": self.svd_solver,
74+
}
75+
76+
assert self.n_components is not None
77+
if 0 < self.n_components < 1:
78+
options["pca_explained_variance_ratio"] = float(self.n_components)
79+
elif self.n_components >= 1:
80+
options["num_principal_components"] = int(self.n_components)
81+
82+
return options
83+
6684
def _fit(
6785
self,
6886
X: Union[bpd.DataFrame, bpd.Series],
@@ -71,14 +89,13 @@ def _fit(
7189
) -> PCA:
7290
(X,) = utils.convert_to_dataframe(X)
7391

92+
# To mimic sklearn's behavior
93+
if self.n_components is None:
94+
self.n_components = min(X.shape)
7495
self._bqml_model = self._bqml_model_factory.create_model(
7596
X_train=X,
7697
transforms=transforms,
77-
options={
78-
"model_type": "PCA",
79-
"num_principal_components": self.n_components,
80-
"pca_solver": self.svd_solver,
81-
},
98+
options=self._bqml_options,
8299
)
83100
return self
84101

bigframes/ml/ensemble.py

+30-39
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
_BQML_PARAMS_MAPPING = {
3232
"booster": "boosterType",
3333
"tree_method": "treeMethod",
34-
"early_stop": "earlyStop",
3534
"colsample_bytree": "colsampleBylevel",
3635
"colsample_bylevel": "colsampleBytree",
3736
"colsample_bynode": "colsampleBynode",
@@ -40,8 +39,8 @@
4039
"reg_alpha": "l1Regularization",
4140
"reg_lambda": "l2Regularization",
4241
"learning_rate": "learnRate",
43-
"min_rel_progress": "minRelativeProgress",
44-
"num_parallel_tree": "numParallelTree",
42+
"tol": "minRelativeProgress",
43+
"n_estimators": "numParallelTree",
4544
"min_tree_child_weight": "minTreeChildWeight",
4645
"max_depth": "maxTreeDepth",
4746
"max_iterations": "maxIterations",
@@ -57,7 +56,7 @@ class XGBRegressor(
5756

5857
def __init__(
5958
self,
60-
num_parallel_tree: int = 1,
59+
n_estimators: int = 1,
6160
*,
6261
booster: Literal["gbtree", "dart"] = "gbtree",
6362
dart_normalized_type: Literal["tree", "forest"] = "tree",
@@ -71,14 +70,13 @@ def __init__(
7170
subsample: float = 1.0,
7271
reg_alpha: float = 0.0,
7372
reg_lambda: float = 1.0,
74-
early_stop: float = True,
7573
learning_rate: float = 0.3,
7674
max_iterations: int = 20,
77-
min_rel_progress: float = 0.01,
75+
tol: float = 0.01,
7876
enable_global_explain: bool = False,
7977
xgboost_version: Literal["0.9", "1.1"] = "0.9",
8078
):
81-
self.num_parallel_tree = num_parallel_tree
79+
self.n_estimators = n_estimators
8280
self.booster = booster
8381
self.dart_normalized_type = dart_normalized_type
8482
self.tree_method = tree_method
@@ -91,10 +89,9 @@ def __init__(
9189
self.subsample = subsample
9290
self.reg_alpha = reg_alpha
9391
self.reg_lambda = reg_lambda
94-
self.early_stop = early_stop
9592
self.learning_rate = learning_rate
9693
self.max_iterations = max_iterations
97-
self.min_rel_progress = min_rel_progress
94+
self.tol = tol
9895
self.enable_global_explain = enable_global_explain
9996
self.xgboost_version = xgboost_version
10097
self._bqml_model: Optional[core.BqmlModel] = None
@@ -127,7 +124,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
127124
return {
128125
"model_type": "BOOSTED_TREE_REGRESSOR",
129126
"data_split_method": "NO_SPLIT",
130-
"num_parallel_tree": self.num_parallel_tree,
127+
"early_stop": True,
128+
"num_parallel_tree": self.n_estimators,
131129
"booster_type": self.booster,
132130
"tree_method": self.tree_method,
133131
"min_tree_child_weight": self.min_tree_child_weight,
@@ -139,10 +137,9 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
139137
"subsample": self.subsample,
140138
"l1_reg": self.reg_alpha,
141139
"l2_reg": self.reg_lambda,
142-
"early_stop": self.early_stop,
143140
"learn_rate": self.learning_rate,
144141
"max_iterations": self.max_iterations,
145-
"min_rel_progress": self.min_rel_progress,
142+
"min_rel_progress": self.tol,
146143
"enable_global_explain": self.enable_global_explain,
147144
"xgboost_version": self.xgboost_version,
148145
}
@@ -215,7 +212,7 @@ class XGBClassifier(
215212

216213
def __init__(
217214
self,
218-
num_parallel_tree: int = 1,
215+
n_estimators: int = 1,
219216
*,
220217
booster: Literal["gbtree", "dart"] = "gbtree",
221218
dart_normalized_type: Literal["tree", "forest"] = "tree",
@@ -229,14 +226,13 @@ def __init__(
229226
subsample: float = 1.0,
230227
reg_alpha: float = 0.0,
231228
reg_lambda: float = 1.0,
232-
early_stop: bool = True,
233229
learning_rate: float = 0.3,
234230
max_iterations: int = 20,
235-
min_rel_progress: float = 0.01,
231+
tol: float = 0.01,
236232
enable_global_explain: bool = False,
237233
xgboost_version: Literal["0.9", "1.1"] = "0.9",
238234
):
239-
self.num_parallel_tree = num_parallel_tree
235+
self.n_estimators = n_estimators
240236
self.booster = booster
241237
self.dart_normalized_type = dart_normalized_type
242238
self.tree_method = tree_method
@@ -249,10 +245,9 @@ def __init__(
249245
self.subsample = subsample
250246
self.reg_alpha = reg_alpha
251247
self.reg_lambda = reg_lambda
252-
self.early_stop = early_stop
253248
self.learning_rate = learning_rate
254249
self.max_iterations = max_iterations
255-
self.min_rel_progress = min_rel_progress
250+
self.tol = tol
256251
self.enable_global_explain = enable_global_explain
257252
self.xgboost_version = xgboost_version
258253
self._bqml_model: Optional[core.BqmlModel] = None
@@ -285,7 +280,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
285280
return {
286281
"model_type": "BOOSTED_TREE_CLASSIFIER",
287282
"data_split_method": "NO_SPLIT",
288-
"num_parallel_tree": self.num_parallel_tree,
283+
"early_stop": True,
284+
"num_parallel_tree": self.n_estimators,
289285
"booster_type": self.booster,
290286
"tree_method": self.tree_method,
291287
"min_tree_child_weight": self.min_tree_child_weight,
@@ -297,10 +293,9 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
297293
"subsample": self.subsample,
298294
"l1_reg": self.reg_alpha,
299295
"l2_reg": self.reg_lambda,
300-
"early_stop": self.early_stop,
301296
"learn_rate": self.learning_rate,
302297
"max_iterations": self.max_iterations,
303-
"min_rel_progress": self.min_rel_progress,
298+
"min_rel_progress": self.tol,
304299
"enable_global_explain": self.enable_global_explain,
305300
"xgboost_version": self.xgboost_version,
306301
}
@@ -371,7 +366,7 @@ class RandomForestRegressor(
371366

372367
def __init__(
373368
self,
374-
num_parallel_tree: int = 100,
369+
n_estimators: int = 100,
375370
*,
376371
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
377372
min_tree_child_weight: int = 1,
@@ -383,12 +378,11 @@ def __init__(
383378
subsample=0.8,
384379
reg_alpha=0.0,
385380
reg_lambda=1.0,
386-
early_stop=True,
387-
min_rel_progress=0.01,
381+
tol=0.01,
388382
enable_global_explain=False,
389383
xgboost_version: Literal["0.9", "1.1"] = "0.9",
390384
):
391-
self.num_parallel_tree = num_parallel_tree
385+
self.n_estimators = n_estimators
392386
self.tree_method = tree_method
393387
self.min_tree_child_weight = min_tree_child_weight
394388
self.colsample_bytree = colsample_bytree
@@ -399,8 +393,7 @@ def __init__(
399393
self.subsample = subsample
400394
self.reg_alpha = reg_alpha
401395
self.reg_lambda = reg_lambda
402-
self.early_stop = early_stop
403-
self.min_rel_progress = min_rel_progress
396+
self.tol = tol
404397
self.enable_global_explain = enable_global_explain
405398
self.xgboost_version = xgboost_version
406399
self._bqml_model: Optional[core.BqmlModel] = None
@@ -432,7 +425,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
432425
"""The model options as they will be set for BQML"""
433426
return {
434427
"model_type": "RANDOM_FOREST_REGRESSOR",
435-
"num_parallel_tree": self.num_parallel_tree,
428+
"early_stop": True,
429+
"num_parallel_tree": self.n_estimators,
436430
"tree_method": self.tree_method,
437431
"min_tree_child_weight": self.min_tree_child_weight,
438432
"colsample_bytree": self.colsample_bytree,
@@ -443,8 +437,7 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
443437
"subsample": self.subsample,
444438
"l1_reg": self.reg_alpha,
445439
"l2_reg": self.reg_lambda,
446-
"early_stop": self.early_stop,
447-
"min_rel_progress": self.min_rel_progress,
440+
"min_rel_progress": self.tol,
448441
"data_split_method": "NO_SPLIT",
449442
"enable_global_explain": self.enable_global_explain,
450443
"xgboost_version": self.xgboost_version,
@@ -536,7 +529,7 @@ class RandomForestClassifier(
536529

537530
def __init__(
538531
self,
539-
num_parallel_tree: int = 100,
532+
n_estimators: int = 100,
540533
*,
541534
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
542535
min_tree_child_weight: int = 1,
@@ -548,12 +541,11 @@ def __init__(
548541
subsample: float = 0.8,
549542
reg_alpha: float = 0.0,
550543
reg_lambda: float = 1.0,
551-
early_stop=True,
552-
min_rel_progress: float = 0.01,
544+
tol: float = 0.01,
553545
enable_global_explain=False,
554546
xgboost_version: Literal["0.9", "1.1"] = "0.9",
555547
):
556-
self.num_parallel_tree = num_parallel_tree
548+
self.n_estimators = n_estimators
557549
self.tree_method = tree_method
558550
self.min_tree_child_weight = min_tree_child_weight
559551
self.colsample_bytree = colsample_bytree
@@ -564,8 +556,7 @@ def __init__(
564556
self.subsample = subsample
565557
self.reg_alpha = reg_alpha
566558
self.reg_lambda = reg_lambda
567-
self.early_stop = early_stop
568-
self.min_rel_progress = min_rel_progress
559+
self.tol = tol
569560
self.enable_global_explain = enable_global_explain
570561
self.xgboost_version = xgboost_version
571562
self._bqml_model: Optional[core.BqmlModel] = None
@@ -597,7 +588,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
597588
"""The model options as they will be set for BQML"""
598589
return {
599590
"model_type": "RANDOM_FOREST_CLASSIFIER",
600-
"num_parallel_tree": self.num_parallel_tree,
591+
"early_stop": True,
592+
"num_parallel_tree": self.n_estimators,
601593
"tree_method": self.tree_method,
602594
"min_tree_child_weight": self.min_tree_child_weight,
603595
"colsample_bytree": self.colsample_bytree,
@@ -608,8 +600,7 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
608600
"subsample": self.subsample,
609601
"l1_reg": self.reg_alpha,
610602
"l2_reg": self.reg_lambda,
611-
"early_stop": self.early_stop,
612-
"min_rel_progress": self.min_rel_progress,
603+
"min_rel_progress": self.tol,
613604
"data_split_method": "NO_SPLIT",
614605
"enable_global_explain": self.enable_global_explain,
615606
"xgboost_version": self.xgboost_version,

0 commit comments

Comments
 (0)