Skip to content

Commit 81d1262

Browse files
SalemJordenSalem Boylandtswast
authored
feat: add ARIMAPlus.coef_ property exposing ML.ARIMA_COEFFICIENTS functionality (#585)
* create_single_timeseries_forecasting_model_test.py code sample * fix: forecast method to forecast time series * pair programming PR draft creation * feature: insoect coefficients * update tests for new feature * add arima_model.coef_ to fetch coefficients * updated tests for coefficients feature * feature update for arima_coefficients * updates to output cols * docstring updates --------- Co-authored-by: Salem Boyland <[email protected]> Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 8e4616b commit 81d1262

File tree

5 files changed

+70
-12
lines changed

5 files changed

+70
-12
lines changed

bigframes/ml/core.py

+5
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def arima_evaluate(self, show_all_candidate_models: bool = False):
205205

206206
return self._session.read_gbq(sql)
207207

208+
def arima_coefficients(self) -> bpd.DataFrame:
209+
sql = self._model_manipulation_sql_generator.ml_arima_coefficients()
210+
211+
return self._session.read_gbq(sql)
212+
208213
def centroids(self) -> bpd.DataFrame:
209214
assert self._model.model_type == "KMEANS"
210215

bigframes/ml/forecasting.py

+21
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,27 @@ def predict(
269269
options={"horizon": horizon, "confidence_level": confidence_level}
270270
)
271271

272+
@property
273+
def coef_(
274+
self,
275+
) -> bpd.DataFrame:
276+
"""Inspect the coefficients of the model.
277+
278+
..note::
279+
280+
Output matches that of the ML.ARIMA_COEFFICIENTS function.
281+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-arima-coefficients
282+
for the outputs relevant to this model type.
283+
284+
Returns:
285+
bigframes.dataframe.DataFrame:
286+
A DataFrame with the coefficients for the model.
287+
"""
288+
289+
if not self._bqml_model:
290+
raise RuntimeError("A model must be fitted before inspect coefficients")
291+
return self._bqml_model.arima_coefficients()
292+
272293
def detect_anomalies(
273294
self,
274295
X: Union[bpd.DataFrame, bpd.Series],

bigframes/ml/sql.py

+4
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
318318
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
319319
({source_sql}))"""
320320

321+
def ml_arima_coefficients(self) -> str:
322+
"""Encode ML.ARIMA_COEFFICIENTS for BQML"""
323+
return f"""SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `{self._model_name}`)"""
324+
321325
# ML evaluation TVFs
322326
def ml_llm_evaluate(
323327
self, source_df: bpd.DataFrame, task_type: Optional[str] = None

tests/system/large/ml/test_forecasting.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pandas as pd
16+
import pytest
1617

1718
from bigframes.ml import forecasting
1819

@@ -31,15 +32,22 @@
3132
]
3233

3334

34-
def test_arima_plus_model_fit_score(
35-
time_series_df_default_index, dataset_id, new_time_series_df
36-
):
35+
@pytest.fixture(scope="module")
36+
def arima_model(time_series_df_default_index):
3737
model = forecasting.ARIMAPlus()
3838
X_train = time_series_df_default_index[["parsed_date"]]
3939
y_train = time_series_df_default_index[["total_visits"]]
4040
model.fit(X_train, y_train)
41+
return model
42+
43+
44+
def test_arima_plus_model_fit_score(
45+
dataset_id,
46+
new_time_series_df,
47+
arima_model,
48+
):
4149

42-
result = model.score(
50+
result = arima_model.score(
4351
new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]]
4452
).to_pandas()
4553
expected = pd.DataFrame(
@@ -56,29 +64,39 @@ def test_arima_plus_model_fit_score(
5664
pd.testing.assert_frame_equal(result, expected, check_exact=False, rtol=0.1)
5765

5866
# save, load to ensure configuration was kept
59-
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
67+
reloaded_model = arima_model.to_gbq(
68+
f"{dataset_id}.temp_arima_plus_model", replace=True
69+
)
6070
assert (
6171
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
6272
)
6373

6474

65-
def test_arima_plus_model_fit_summary(time_series_df_default_index, dataset_id):
66-
model = forecasting.ARIMAPlus()
67-
X_train = time_series_df_default_index[["parsed_date"]]
68-
y_train = time_series_df_default_index[["total_visits"]]
69-
model.fit(X_train, y_train)
75+
def test_arima_plus_model_fit_summary(dataset_id, arima_model):
7076

71-
result = model.summary()
77+
result = arima_model.summary()
7278
assert result.shape == (1, 12)
7379
assert all(column in result.columns for column in ARIMA_EVALUATE_OUTPUT_COL)
7480

7581
# save, load to ensure configuration was kept
76-
reloaded_model = model.to_gbq(f"{dataset_id}.temp_arima_plus_model", replace=True)
82+
reloaded_model = arima_model.to_gbq(
83+
f"{dataset_id}.temp_arima_plus_model", replace=True
84+
)
7785
assert (
7886
f"{dataset_id}.temp_arima_plus_model" in reloaded_model._bqml_model.model_name
7987
)
8088

8189

90+
def test_arima_coefficients(arima_model):
91+
got = arima_model.coef_
92+
expected_columns = {
93+
"ar_coefficients",
94+
"ma_coefficients",
95+
"intercept_or_drift",
96+
}
97+
assert set(got.columns) == expected_columns
98+
99+
82100
def test_arima_plus_model_fit_params(time_series_df_default_index, dataset_id):
83101
model = forecasting.ARIMAPlus(
84102
horizon=100,

tests/unit/ml/test_sql.py

+10
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ def mock_df():
4747
return mock_df
4848

4949

50+
def test_ml_arima_coefficients(
51+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
52+
):
53+
sql = model_manipulation_sql_generator.ml_arima_coefficients()
54+
assert (
55+
sql
56+
== """SELECT * FROM ML.ARIMA_COEFFICIENTS(MODEL `my_project_id.my_dataset_id.my_model_id`)"""
57+
)
58+
59+
5060
def test_options_correct(base_sql_generator: ml_sql.BaseSqlGenerator):
5161
sql = base_sql_generator.options(
5262
model_type="lin_reg", input_label_cols=["col_a"], l1_reg=0.6

0 commit comments

Comments
 (0)