Skip to content

Commit 352cb85

Browse files
authored
feat: add ml ARIMAPlus model params (#488)
1 parent 60d4a7b commit 352cb85

File tree

2 files changed

+192
-6
lines changed

2 files changed

+192
-6
lines changed

bigframes/ml/forecasting.py

+151-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Dict, List, Optional, Union
19+
from typing import List, Optional, Union
2020

2121
from google.cloud import bigquery
2222

@@ -25,29 +25,174 @@
2525
from bigframes.ml import base, core, globals, utils
2626
import bigframes.pandas as bpd
2727

28+
# Translation table: ARIMAPlus constructor/attribute name (snake_case) on the
# left, the camelCase key looked up in a BigQuery model's "trainingOptions"
# on the right. Names without an entry here are not restored when a model is
# loaded back from BigQuery.
_BQML_PARAMS_MAPPING = {
    # Forecast length and non-seasonal order search.
    "horizon": "horizon",
    "auto_arima": "autoArima",
    "auto_arima_max_order": "autoArimaMaxOrder",
    "auto_arima_min_order": "autoArimaMinOrder",
    "order": "nonSeasonalOrder",
    # Input series description.
    "data_frequency": "dataFrequency",
    "holiday_region": "holidayRegion",
    # Preprocessing switches.
    "clean_spikes_and_dips": "cleanSpikesAndDips",
    "adjust_step_changes": "adjustStepChanges",
    # Trend-component tuning.
    "time_series_length_fraction": "timeSeriesLengthFraction",
    "min_time_series_length": "minTimeSeriesLength",
    "max_time_series_length": "maxTimeSeriesLength",
    "decompose_time_series": "decomposeTimeSeries",
    "trend_smoothing_window_size": "trendSmoothingWindowSize",
}
44+
2845

2946
@log_adapter.class_logger
3047
class ARIMAPlus(base.SupervisedTrainablePredictor):
31-
"""Time Series ARIMA Plus model."""
48+
"""Time Series ARIMA Plus model.
49+
50+
Args:
51+
horizon (int, default 1,000):
52+
The number of time points to forecast. Defaults to 1,000; the maximum value is 10,000.
53+
54+
auto_arima (bool, default True):
55+
Determines whether the training process uses auto.ARIMA or not. If True, training automatically finds the best non-seasonal order (that is, the p, d, q tuple) and decides whether or not to include a linear drift term when d is 1.
56+
57+
auto_arima_max_order (int or None, default None):
58+
The maximum value for the sum of non-seasonal p and q.
59+
60+
auto_arima_min_order (int or None, default None):
61+
The minimum value for the sum of non-seasonal p and q.
62+
63+
data_frequency (str, default "auto_frequency"):
64+
The data frequency of the input time series.
65+
Possible values are "auto_frequency", "per_minute", "hourly", "daily", "weekly", "monthly", "quarterly", "yearly"
66+
67+
include_drift (bool, default False):
68+
Determines whether the model should include a linear drift term or not. The drift term is applicable when non-seasonal d is 1.
69+
70+
holiday_region (str or None, default None):
71+
The geographical region based on which the holiday effect is applied in modeling. By default, holiday effect modeling isn't used.
72+
For possible values, see https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-time-series#holiday_region.
73+
74+
clean_spikes_and_dips (bool, default True):
75+
Determines whether or not to perform automatic spikes and dips detection and cleanup in the model training pipeline. The spikes and dips are replaced with local linear interpolated values when they're detected.
76+
77+
adjust_step_changes (bool, default True):
78+
Determines whether or not to perform automatic step change detection and adjustment in the model training pipeline.
79+
80+
time_series_length_fraction (float or None, default None):
81+
The fraction of the interpolated length of the time series that's used to model the time series trend component. All of the time points of the time series are used to model the non-trend component.
82+
83+
min_time_series_length (int or None, default None):
84+
The minimum number of time points that are used in modeling the trend component of the time series.
85+
86+
max_time_series_length (int or None, default None):
87+
The maximum number of time points in a time series that can be used in modeling the trend component of the time series.
88+
89+
trend_smoothing_window_size (int or None, default None):
90+
The smoothing window size for the trend component.
91+
92+
decompose_time_series (bool, default True):
93+
Determines whether the separate components of both the history and forecast parts of the time series (such as holiday effect and seasonal components) are saved in the model.
94+
"""
95+
96+
def __init__(
    self,
    *,
    horizon: int = 1000,
    auto_arima: bool = True,
    auto_arima_max_order: Optional[int] = None,
    auto_arima_min_order: Optional[int] = None,
    data_frequency: str = "auto_frequency",
    include_drift: bool = False,
    holiday_region: Optional[str] = None,
    clean_spikes_and_dips: bool = True,
    adjust_step_changes: bool = True,
    time_series_length_fraction: Optional[float] = None,
    min_time_series_length: Optional[int] = None,
    max_time_series_length: Optional[int] = None,
    trend_smoothing_window_size: Optional[int] = None,
    decompose_time_series: bool = True,
):
    """Record the keyword-only hyperparameters on the instance.

    Every argument is stored unchanged under an attribute of the same name;
    nothing is validated here. The attribute assignment order is kept
    stable because ``_from_bq`` walks ``__dict__`` of a default instance.
    """
    # Forecast length and automatic order search.
    self.horizon = horizon
    self.auto_arima = auto_arima
    self.auto_arima_max_order = auto_arima_max_order
    self.auto_arima_min_order = auto_arima_min_order

    # Description of the input series.
    self.data_frequency = data_frequency
    self.include_drift = include_drift
    self.holiday_region = holiday_region

    # Preprocessing switches.
    self.clean_spikes_and_dips = clean_spikes_and_dips
    self.adjust_step_changes = adjust_step_changes

    # Trend-component tuning.
    self.time_series_length_fraction = time_series_length_fraction
    self.min_time_series_length = min_time_series_length
    self.max_time_series_length = max_time_series_length
    self.trend_smoothing_window_size = trend_smoothing_window_size
    self.decompose_time_series = decompose_time_series
    # TODO(garrettwu) add order and seasonalities params, which need struct/array

    # No backing BQML model exists until one is attached (see _from_bq).
    self._bqml_model: Optional[core.BqmlModel] = None
    self._bqml_model_factory = globals.bqml_model_factory()
36132

37133
@classmethod
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ARIMAPlus:
    """Rebuild an ARIMAPlus estimator from an existing BigQuery model.

    Args:
        session: Session handed to ``core.BqmlModel`` together with ``model``.
        model: BigQuery model; its ``model_type`` must be ``"ARIMA_PLUS"``.

    Returns:
        A new estimator whose constructor arguments are taken from the
        ``"trainingOptions"`` of the last entry in ``model.training_runs``
        and whose ``_bqml_model`` wraps ``model``.
    """
    assert model.model_type == "ARIMA_PLUS"

    # These parameters default to None, so type(default) cannot be used to
    # convert the stored value; their target types are listed explicitly.
    explicit_converters = {
        "time_series_length_fraction": float,
        "auto_arima_max_order": int,
        "auto_arima_min_order": int,
        "min_time_series_length": int,
        "max_time_series_length": int,
        "trend_smoothing_window_size": int,
        "holiday_region": str,
    }

    # Last entry in training_runs -- presumably the most recent fit.
    training_options = model.training_runs[-1]["trainingOptions"]

    restored: dict = {}
    # A default-constructed instance supplies both the attribute names to
    # look for and, via its default values, the type to convert into.
    for name, default in cls().__dict__.items():
        option_key = _BQML_PARAMS_MAPPING.get(name)
        # Attributes without a mapping yield None here and are skipped, as
        # are options absent from the stored training options.
        if option_key not in training_options:
            continue
        convert = explicit_converters.get(name, type(default))
        restored[name] = convert(training_options[option_key])

    estimator = cls(**restored)
    estimator._bqml_model = core.BqmlModel(session, model)
    return estimator
46163

47164
@property
def _bqml_options(self) -> dict:
    """The model options as they will be set for BQML."""
    # Options that are always sent, in this order, after the model type.
    always_sent = (
        "horizon",
        "auto_arima",
        "data_frequency",
        "clean_spikes_and_dips",
        "adjust_step_changes",
        "decompose_time_series",
    )
    # Options that are sent only when the caller supplied a value.
    sent_when_set = (
        "auto_arima_max_order",
        "auto_arima_min_order",
        "holiday_region",
        "time_series_length_fraction",
        "min_time_series_length",
        "max_time_series_length",
        "trend_smoothing_window_size",
    )

    bqml_options = {"model_type": "ARIMA_PLUS"}
    bqml_options.update((name, getattr(self, name)) for name in always_sent)

    for name in sent_when_set:
        value = getattr(self, name)
        if value is not None:
            bqml_options[name] = value

    # include_drift is only ever emitted as True; a falsy value is omitted.
    if self.include_drift:
        bqml_options["include_drift"] = True

    return bqml_options
51196

52197
def _fit(
53198
self,

tests/system/large/ml/test_forecasting.py

+41
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,44 @@ def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
7777
assert (
7878
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
7979
)
80+
81+
82+
def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id):
    """Fit ARIMAPlus with non-default options and check they survive a save/load."""
    model_path = f"{dataset_id}.temp_arima_plus_model"

    arima = forecasting.ARIMAPlus(
        horizon=100,
        auto_arima=True,
        auto_arima_max_order=4,
        auto_arima_min_order=1,
        data_frequency="daily",
        holiday_region="US",
        clean_spikes_and_dips=False,
        adjust_step_changes=False,
        time_series_length_fraction=0.5,
        min_time_series_length=10,
        trend_smoothing_window_size=5,
        decompose_time_series=False,
    )

    timestamps = time_series_df_default_index[["parsed_date"]]
    visits = time_series_df_default_index[["total_visits"]]
    arima.fit(timestamps, visits)

    # save, load to ensure configuration was kept
    reloaded_model = arima.to_gbq(model_path, replace=True)
    assert model_path in reloaded_model._bqml_model.model_name

    # Order-search settings.
    assert reloaded_model.horizon == 100
    assert reloaded_model.auto_arima is True
    assert reloaded_model.auto_arima_max_order == 4
    # TODO(garrettwu): now BQML doesn't populate auto_arima_min_order
    # assert reloaded_model.auto_arima_min_order == 1

    # The frequency was passed as "daily" but is expected back upper-cased.
    assert reloaded_model.data_frequency == "DAILY"
    assert reloaded_model.holiday_region == "US"

    # Preprocessing switches.
    assert reloaded_model.clean_spikes_and_dips is False
    assert reloaded_model.adjust_step_changes is False

    # Trend-component tuning.
    assert reloaded_model.time_series_length_fraction == 0.5
    assert reloaded_model.min_time_series_length == 10
    assert reloaded_model.trend_smoothing_window_size == 5
    assert reloaded_model.decompose_time_series is False

0 commit comments

Comments
 (0)