13
13
# limitations under the License.
14
14
15
15
import pandas as pd
16
+ import pytest
16
17
17
18
from bigframes .ml import forecasting
18
19
31
32
]
32
33
33
34
34
- def test_arima_plus_model_fit_score (
35
- time_series_df_default_index , dataset_id , new_time_series_df
36
- ):
35
+ @pytest .fixture (scope = "module" )
36
+ def arima_model (time_series_df_default_index ):
37
37
model = forecasting .ARIMAPlus ()
38
38
X_train = time_series_df_default_index [["parsed_date" ]]
39
39
y_train = time_series_df_default_index [["total_visits" ]]
40
40
model .fit (X_train , y_train )
41
+ return model
42
+
43
+
44
+ def test_arima_plus_model_fit_score (
45
+ dataset_id ,
46
+ new_time_series_df ,
47
+ arima_model ,
48
+ ):
41
49
42
- result = model .score (
50
+ result = arima_model .score (
43
51
new_time_series_df [["parsed_date" ]], new_time_series_df [["total_visits" ]]
44
52
).to_pandas ()
45
53
expected = pd .DataFrame (
@@ -56,29 +64,39 @@ def test_arima_plus_model_fit_score(
56
64
pd .testing .assert_frame_equal (result , expected , check_exact = False , rtol = 0.1 )
57
65
58
66
# save, load to ensure configuration was kept
59
- reloaded_model = model .to_gbq (f"{ dataset_id } .temp_arima_plus_model" , replace = True )
67
+ reloaded_model = arima_model .to_gbq (
68
+ f"{ dataset_id } .temp_arima_plus_model" , replace = True
69
+ )
60
70
assert (
61
71
f"{ dataset_id } .temp_arima_plus_model" in reloaded_model ._bqml_model .model_name
62
72
)
63
73
64
74
65
- def test_arima_plus_model_fit_summary (time_series_df_default_index , dataset_id ):
66
- model = forecasting .ARIMAPlus ()
67
- X_train = time_series_df_default_index [["parsed_date" ]]
68
- y_train = time_series_df_default_index [["total_visits" ]]
69
- model .fit (X_train , y_train )
75
+ def test_arima_plus_model_fit_summary (dataset_id , arima_model ):
70
76
71
- result = model .summary ()
77
+ result = arima_model .summary ()
72
78
assert result .shape == (1 , 12 )
73
79
assert all (column in result .columns for column in ARIMA_EVALUATE_OUTPUT_COL )
74
80
75
81
# save, load to ensure configuration was kept
76
- reloaded_model = model .to_gbq (f"{ dataset_id } .temp_arima_plus_model" , replace = True )
82
+ reloaded_model = arima_model .to_gbq (
83
+ f"{ dataset_id } .temp_arima_plus_model" , replace = True
84
+ )
77
85
assert (
78
86
f"{ dataset_id } .temp_arima_plus_model" in reloaded_model ._bqml_model .model_name
79
87
)
80
88
81
89
90
+ def test_arima_coefficients (arima_model ):
91
+ got = arima_model .coef_
92
+ expected_columns = {
93
+ "ar_coefficients" ,
94
+ "ma_coefficients" ,
95
+ "intercept_or_drift" ,
96
+ }
97
+ assert set (got .columns ) == expected_columns
98
+
99
+
82
100
def test_arima_plus_model_fit_params (time_series_df_default_index , dataset_id ):
83
101
model = forecasting .ARIMAPlus (
84
102
horizon = 100 ,
0 commit comments