Skip to content

Commit 97532c9

Browse files
Shuowei Lishuoweil
Shuowei Li
andauthored
feat: support time_series_id_col in ARIMAPlus (#1282)
* manually port from shuowei-arima-plus branch, now I cannot pass format test * I have resolve all conflicts after manual porting * use inherritance for arima plus model, and add sql models * resolve unexpected indent for docstring --------- Co-authored-by: Shuowei Li <[email protected]>
1 parent 2ba59e5 commit 97532c9

File tree

10 files changed

+1486
-545
lines changed

10 files changed

+1486
-545
lines changed

bigframes/ml/base.py

+28
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,34 @@ def fit(
165165
return self._fit(X, y)
166166

167167

168+
class SupervisedTrainableWithIdColPredictor(SupervisedTrainablePredictor):
169+
"""Inherits from SupervisedTrainablePredictor,
170+
but adds an optional id_col parameter to fit()."""
171+
172+
def __init__(self):
173+
super().__init__()
174+
self.id_col = None
175+
176+
def _fit(
177+
self,
178+
X: utils.ArrayType,
179+
y: utils.ArrayType,
180+
transforms=None,
181+
id_col: Optional[utils.ArrayType] = None,
182+
):
183+
return self
184+
185+
def fit(
186+
self,
187+
X: utils.ArrayType,
188+
y: utils.ArrayType,
189+
transforms=None,
190+
id_col: Optional[utils.ArrayType] = None,
191+
):
192+
self.id_col = id_col
193+
return self._fit(X, y, transforms=transforms, id_col=self.id_col)
194+
195+
168196
class TrainableWithEvaluationPredictor(TrainablePredictor):
169197
"""A BigQuery DataFrames ML Model base class that can be used to fit and predict outputs.
170198

bigframes/ml/core.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,23 @@ def detect_anomalies(
181181

182182
def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
183183
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
184-
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()
184+
timestamp_col_name = "forecast_timestamp"
185+
index_cols = [timestamp_col_name]
186+
first_col_name = self._session.read_gbq(sql).columns.values[0]
187+
if timestamp_col_name != first_col_name:
188+
index_cols.append(first_col_name)
189+
return self._session.read_gbq(sql, index_col=index_cols).reset_index()
185190

186191
def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
187192
sql = self._model_manipulation_sql_generator.ml_explain_forecast(
188193
struct_options=options
189194
)
190-
return self._session.read_gbq(
191-
sql, index_col="time_series_timestamp"
192-
).reset_index()
195+
timestamp_col_name = "time_series_timestamp"
196+
index_cols = [timestamp_col_name]
197+
first_col_name = self._session.read_gbq(sql).columns.values[0]
198+
if timestamp_col_name != first_col_name:
199+
index_cols.append(first_col_name)
200+
return self._session.read_gbq(sql, index_col=index_cols).reset_index()
193201

194202
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
195203
sql = self._model_manipulation_sql_generator.ml_evaluate(
@@ -390,6 +398,7 @@ def create_time_series_model(
390398
self,
391399
X_train: bpd.DataFrame,
392400
y_train: bpd.DataFrame,
401+
id_col: Optional[bpd.DataFrame] = None,
393402
transforms: Optional[Iterable[str]] = None,
394403
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
395404
) -> BqmlModel:
@@ -399,13 +408,21 @@ def create_time_series_model(
399408
assert (
400409
y_train.columns.size == 1
401410
), "Time stamp data input must only contain 1 column."
411+
assert id_col is None or (
412+
id_col is not None and id_col.columns.size == 1
413+
), "Time series id input is either None or must only contain 1 column."
402414

403415
options = dict(options)
404416
# Cache dataframes to make sure base table is not a snapshot
405417
# cached dataframe creates a full copy, never uses snapshot
406-
input_data = X_train.join(y_train, how="outer").cache()
418+
input_data = X_train.join(y_train, how="outer")
419+
if id_col is not None:
420+
input_data = input_data.join(id_col, how="outer")
421+
input_data = input_data.cache()
407422
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
408423
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})
424+
if id_col is not None:
425+
options.update({"TIME_SERIES_ID_COL": id_col.columns.tolist()[0]})
409426

410427
session = X_train._session
411428
model_ref = self._create_model_ref(session._anonymous_dataset)

bigframes/ml/forecasting.py

+44-14
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646

4747
@log_adapter.class_logger
48-
class ARIMAPlus(base.SupervisedTrainablePredictor):
48+
class ARIMAPlus(base.SupervisedTrainableWithIdColPredictor):
4949
"""Time Series ARIMA Plus model.
5050
5151
Args:
@@ -183,37 +183,53 @@ def _fit(
183183
X: utils.ArrayType,
184184
y: utils.ArrayType,
185185
transforms: Optional[List[str]] = None,
186-
):
186+
id_col: Optional[utils.ArrayType] = None,
187+
) -> ARIMAPlus:
187188
"""Fit the model to training data.
188189
189190
Args:
190-
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
191-
A dataframe of training timestamp.
192-
193-
y (bigframes.dataframe.DataFrame or bigframes.series.Series):
191+
X (bigframes.dataframe.DataFrame or bigframes.series.Series,
192+
or pandas.core.frame.DataFrame or pandas.core.series.Series):
193+
A dataframe or series of trainging timestamp.
194+
y (bigframes.dataframe.DataFrame, or bigframes.series.Series,
195+
or pandas.core.frame.DataFrame, or pandas.core.series.Series):
194196
Target values for training.
195197
transforms (Optional[List[str]], default None):
196198
Do not use. Internal param to be deprecated.
197199
Use bigframes.ml.pipeline instead.
200+
id_col (Optional[bigframes.dataframe.DataFrame]
201+
or Optional[bigframes.series.Series]
202+
or Optional[pandas.core.frame.DataFrame]
203+
or Optional[pandas.core.frame.Series]
204+
or None, default None):
205+
An optional dataframe or series of training id col.
198206
199207
Returns:
200208
ARIMAPlus: Fitted estimator.
201209
"""
202210
X, y = utils.batch_convert_to_dataframe(X, y)
203211

204212
if X.columns.size != 1:
205-
raise ValueError(
206-
"Time series timestamp input X must only contain 1 column."
207-
)
213+
raise ValueError("Time series timestamp input X contain at least 1 column.")
208214
if y.columns.size != 1:
209215
raise ValueError("Time series data input y must only contain 1 column.")
210216

217+
if id_col is not None:
218+
(id_col,) = utils.batch_convert_to_dataframe(id_col)
219+
220+
if id_col.columns.size != 1:
221+
raise ValueError(
222+
"Time series id input id_col must only contain 1 column."
223+
)
224+
211225
self._bqml_model = self._bqml_model_factory.create_time_series_model(
212226
X,
213227
y,
228+
id_col=id_col,
214229
transforms=transforms,
215230
options=self._bqml_options,
216231
)
232+
return self
217233

218234
def predict(
219235
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
@@ -237,7 +253,7 @@ def predict(
237253
238254
Returns:
239255
bigframes.dataframe.DataFrame: The predicted DataFrames. Which
240-
contains 2 columns: "forecast_timestamp" and "forecast_value".
256+
contains 2 columns: "forecast_timestamp", "id" as optional, and "forecast_value".
241257
"""
242258
if horizon < 1 or horizon > 1000:
243259
raise ValueError(f"horizon must be [1, 1000], but is {horizon}.")
@@ -345,6 +361,7 @@ def score(
345361
self,
346362
X: utils.ArrayType,
347363
y: utils.ArrayType,
364+
id_col: Optional[utils.ArrayType] = None,
348365
) -> bpd.DataFrame:
349366
"""Calculate evaluation metrics of the model.
350367
@@ -355,13 +372,22 @@ def score(
355372
for the outputs relevant to this model type.
356373
357374
Args:
358-
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
359-
A BigQuery DataFrame only contains 1 column as
375+
X (bigframes.dataframe.DataFrame or bigframes.series.Series
376+
or pandas.core.frame.DataFrame or pandas.core.series.Series):
377+
A dataframe or series only contains 1 column as
360378
evaluation timestamp. The timestamp must be within the horizon
361379
of the model, which by default is 1000 data points.
362-
y (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
363-
A BigQuery DataFrame only contains 1 column as
380+
y (bigframes.dataframe.DataFrame or bigframes.series.Series
381+
or pandas.core.frame.DataFrame or pandas.core.series.Series):
382+
A dataframe or series only contains 1 column as
364383
evaluation numeric values.
384+
id_col (Optional[bigframes.dataframe.DataFrame]
385+
or Optional[bigframes.series.Series]
386+
or Optional[pandas.core.frame.DataFrame]
387+
or Optional[pandas.core.series.Series]
388+
or None, default None):
389+
An optional dataframe or series contains at least 1 column as
390+
evaluation id column.
365391
366392
Returns:
367393
bigframes.dataframe.DataFrame: A DataFrame as evaluation result.
@@ -371,6 +397,10 @@ def score(
371397
X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)
372398

373399
input_data = X.join(y, how="outer")
400+
if id_col is not None:
401+
(id_col,) = utils.batch_convert_to_dataframe(id_col)
402+
input_data = input_data.join(id_col, how="outer")
403+
374404
return self._bqml_model.evaluate(input_data)
375405

376406
def summary(

0 commit comments

Comments
 (0)