Skip to content

Commit fe7a902

Browse files
feat: Update proto definitions for bigquery/v2 to support new proto fields for BQML. (#817)
PiperOrigin-RevId: 387137741 Source-Link: googleapis/googleapis@8962c92 Source-Link: googleapis/googleapis-gen@102f1b4
1 parent 3c1be14 commit fe7a902

File tree

2 files changed

+107
-9
lines changed

2 files changed

+107
-9
lines changed

google/cloud/bigquery_v2/types/model.py

+95-9
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class Model(proto.Message):
9696
Output only. Label columns that were used to train this
9797
model. The output of the model will have a `predicted_`
9898
prefix to these columns.
99+
best_trial_id (int):
100+
The best trial_id across all training runs.
99101
"""
100102

101103
class ModelType(proto.Enum):
@@ -113,6 +115,7 @@ class ModelType(proto.Enum):
113115
ARIMA = 11
114116
AUTOML_REGRESSOR = 12
115117
AUTOML_CLASSIFIER = 13
118+
ARIMA_PLUS = 19
116119

117120
class LossType(proto.Enum):
118121
r"""Loss metric to evaluate model training performance."""
@@ -151,6 +154,7 @@ class DataFrequency(proto.Enum):
151154
WEEKLY = 5
152155
DAILY = 6
153156
HOURLY = 7
157+
PER_MINUTE = 8
154158

155159
class HolidayRegion(proto.Enum):
156160
r"""Type of supported holiday regions for time series forecasting
@@ -285,7 +289,7 @@ class RegressionMetrics(proto.Message):
285289
median_absolute_error (google.protobuf.wrappers_pb2.DoubleValue):
286290
Median absolute error.
287291
r_squared (google.protobuf.wrappers_pb2.DoubleValue):
288-
R^2 score.
292+
R^2 score. This corresponds to r2_score in ML.EVALUATE.
289293
"""
290294

291295
mean_absolute_error = proto.Field(
@@ -528,7 +532,7 @@ class ClusteringMetrics(proto.Message):
528532
Mean of squared distances between each sample
529533
to its cluster centroid.
530534
clusters (Sequence[google.cloud.bigquery_v2.types.Model.ClusteringMetrics.Cluster]):
531-
[Beta] Information for all clusters.
535+
Information for all clusters.
532536
"""
533537

534538
class Cluster(proto.Message):
@@ -697,10 +701,29 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
697701
Whether the ARIMA model is fitted with drift or not. It
698702
is always false when d is not 1.
699703
time_series_id (str):
700-
The id to indicate different time series.
704+
The time_series_id value for this time series. It will be
705+
one of the unique values from the time_series_id_column
706+
specified during ARIMA model training. Only present when
707+
time_series_id_column training option was used.
708+
time_series_ids (Sequence[str]):
709+
The tuple of time_series_ids identifying this time series.
710+
It will be one of the unique tuples of values present in the
711+
time_series_id_columns specified during ARIMA model
712+
training. Only present when time_series_id_columns training
713+
option was used and the order of values here is the same as the
714+
order of time_series_id_columns.
701715
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
702716
Seasonal periods. Repeated because multiple
703717
periods are supported for one time series.
718+
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
719+
If true, holiday_effect is a part of time series
720+
decomposition result.
721+
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
722+
If true, spikes_and_dips is a part of time series
723+
decomposition result.
724+
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
725+
If true, step_changes is a part of time series decomposition
726+
result.
704727
"""
705728

706729
non_seasonal_order = proto.Field(
@@ -711,9 +734,19 @@ class ArimaSingleModelForecastingMetrics(proto.Message):
711734
)
712735
has_drift = proto.Field(proto.BOOL, number=3,)
713736
time_series_id = proto.Field(proto.STRING, number=4,)
737+
time_series_ids = proto.RepeatedField(proto.STRING, number=9,)
714738
seasonal_periods = proto.RepeatedField(
715739
proto.ENUM, number=5, enum="Model.SeasonalPeriod.SeasonalPeriodType",
716740
)
741+
has_holiday_effect = proto.Field(
742+
proto.MESSAGE, number=6, message=wrappers_pb2.BoolValue,
743+
)
744+
has_spikes_and_dips = proto.Field(
745+
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
746+
)
747+
has_step_changes = proto.Field(
748+
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
749+
)
717750

718751
non_seasonal_order = proto.RepeatedField(
719752
proto.MESSAGE, number=1, message="Model.ArimaOrder",
@@ -901,7 +934,7 @@ class TrainingRun(proto.Message):
901934
"""
902935

903936
class TrainingOptions(proto.Message):
904-
r"""
937+
r"""Options used in model training.
905938
Attributes:
906939
max_iterations (int):
907940
The maximum number of iterations in training.
@@ -972,8 +1005,9 @@ class TrainingOptions(proto.Message):
9721005
num_clusters (int):
9731006
Number of clusters for clustering models.
9741007
model_uri (str):
975-
[Beta] Google Cloud Storage URI from which the model was
976-
imported. Only applicable for imported models.
1008+
Google Cloud Storage URI from which the model
1009+
was imported. Only applicable for imported
1010+
models.
9771011
optimization_strategy (google.cloud.bigquery_v2.types.Model.OptimizationStrategy):
9781012
Optimization strategy for training linear
9791013
regression models.
@@ -1030,8 +1064,11 @@ class TrainingOptions(proto.Message):
10301064
If a valid value is specified, then holiday
10311065
effects modeling is enabled.
10321066
time_series_id_column (str):
1033-
The id column that will be used to indicate
1034-
different time series to forecast in parallel.
1067+
The time series id column that was used
1068+
during ARIMA model training.
1069+
time_series_id_columns (Sequence[str]):
1070+
The time series id columns that were used
1071+
during ARIMA model training.
10351072
horizon (int):
10361073
The number of periods ahead that need to be
10371074
forecasted.
@@ -1042,6 +1079,15 @@ class TrainingOptions(proto.Message):
10421079
output feature name is A.b.
10431080
auto_arima_max_order (int):
10441081
The max value of non-seasonal p and q.
1082+
decompose_time_series (google.protobuf.wrappers_pb2.BoolValue):
1083+
If true, perform time series decomposition and
1084+
save the results.
1085+
clean_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
1086+
If true, clean spikes and dips in the input
1087+
time series.
1088+
adjust_step_changes (google.protobuf.wrappers_pb2.BoolValue):
1089+
If true, detect step changes and make data
1090+
adjustment in the input time series.
10451091
"""
10461092

10471093
max_iterations = proto.Field(proto.INT64, number=1,)
@@ -1120,9 +1166,19 @@ class TrainingOptions(proto.Message):
11201166
proto.ENUM, number=42, enum="Model.HolidayRegion",
11211167
)
11221168
time_series_id_column = proto.Field(proto.STRING, number=43,)
1169+
time_series_id_columns = proto.RepeatedField(proto.STRING, number=51,)
11231170
horizon = proto.Field(proto.INT64, number=44,)
11241171
preserve_input_structs = proto.Field(proto.BOOL, number=45,)
11251172
auto_arima_max_order = proto.Field(proto.INT64, number=46,)
1173+
decompose_time_series = proto.Field(
1174+
proto.MESSAGE, number=50, message=wrappers_pb2.BoolValue,
1175+
)
1176+
clean_spikes_and_dips = proto.Field(
1177+
proto.MESSAGE, number=52, message=wrappers_pb2.BoolValue,
1178+
)
1179+
adjust_step_changes = proto.Field(
1180+
proto.MESSAGE, number=53, message=wrappers_pb2.BoolValue,
1181+
)
11261182

11271183
class IterationResult(proto.Message):
11281184
r"""Information about a single iteration of the training run.
@@ -1218,10 +1274,29 @@ class ArimaModelInfo(proto.Message):
12181274
Whether the ARIMA model is fitted with drift or not.
12191275
It is always false when d is not 1.
12201276
time_series_id (str):
1221-
The id to indicate different time series.
1277+
The time_series_id value for this time series. It will be
1278+
one of the unique values from the time_series_id_column
1279+
specified during ARIMA model training. Only present when
1280+
time_series_id_column training option was used.
1281+
time_series_ids (Sequence[str]):
1282+
The tuple of time_series_ids identifying this time series.
1283+
It will be one of the unique tuples of values present in the
1284+
time_series_id_columns specified during ARIMA model
1285+
training. Only present when time_series_id_columns training
1286+
option was used and the order of values here is the same as the
1287+
order of time_series_id_columns.
12221288
seasonal_periods (Sequence[google.cloud.bigquery_v2.types.Model.SeasonalPeriod.SeasonalPeriodType]):
12231289
Seasonal periods. Repeated because multiple
12241290
periods are supported for one time series.
1291+
has_holiday_effect (google.protobuf.wrappers_pb2.BoolValue):
1292+
If true, holiday_effect is a part of time series
1293+
decomposition result.
1294+
has_spikes_and_dips (google.protobuf.wrappers_pb2.BoolValue):
1295+
If true, spikes_and_dips is a part of time series
1296+
decomposition result.
1297+
has_step_changes (google.protobuf.wrappers_pb2.BoolValue):
1298+
If true, step_changes is a part of time series decomposition
1299+
result.
12251300
"""
12261301

12271302
non_seasonal_order = proto.Field(
@@ -1237,11 +1312,21 @@ class ArimaModelInfo(proto.Message):
12371312
)
12381313
has_drift = proto.Field(proto.BOOL, number=4,)
12391314
time_series_id = proto.Field(proto.STRING, number=5,)
1315+
time_series_ids = proto.RepeatedField(proto.STRING, number=10,)
12401316
seasonal_periods = proto.RepeatedField(
12411317
proto.ENUM,
12421318
number=6,
12431319
enum="Model.SeasonalPeriod.SeasonalPeriodType",
12441320
)
1321+
has_holiday_effect = proto.Field(
1322+
proto.MESSAGE, number=7, message=wrappers_pb2.BoolValue,
1323+
)
1324+
has_spikes_and_dips = proto.Field(
1325+
proto.MESSAGE, number=8, message=wrappers_pb2.BoolValue,
1326+
)
1327+
has_step_changes = proto.Field(
1328+
proto.MESSAGE, number=9, message=wrappers_pb2.BoolValue,
1329+
)
12451330

12461331
arima_model_info = proto.RepeatedField(
12471332
proto.MESSAGE,
@@ -1319,6 +1404,7 @@ class ArimaModelInfo(proto.Message):
13191404
label_columns = proto.RepeatedField(
13201405
proto.MESSAGE, number=11, message=standard_sql.StandardSqlField,
13211406
)
1407+
best_trial_id = proto.Field(proto.INT64, number=19,)
13221408

13231409

13241410
class GetModelRequest(proto.Message):

google/cloud/bigquery_v2/types/table_reference.py

+12
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,23 @@ class TableReference(proto.Message):
3636
maximum length is 1,024 characters. Certain operations allow
3737
suffixing of the table ID with a partition decorator, such
3838
as ``sample_table$20190123``.
39+
project_id_alternative (Sequence[str]):
40+
The alternative field that will be used when ESF is not able
41+
to translate the received data to the project_id field.
42+
dataset_id_alternative (Sequence[str]):
43+
The alternative field that will be used when ESF is not able
44+
to translate the received data to the dataset_id field.
45+
table_id_alternative (Sequence[str]):
46+
The alternative field that will be used when ESF is not able
47+
to translate the received data to the table_id field.
3948
"""
4049

4150
project_id = proto.Field(proto.STRING, number=1,)
4251
dataset_id = proto.Field(proto.STRING, number=2,)
4352
table_id = proto.Field(proto.STRING, number=3,)
53+
project_id_alternative = proto.RepeatedField(proto.STRING, number=4,)
54+
dataset_id_alternative = proto.RepeatedField(proto.STRING, number=5,)
55+
table_id_alternative = proto.RepeatedField(proto.STRING, number=6,)
4456

4557

4658
__all__ = tuple(sorted(__protobuf__.manifest))

0 commit comments

Comments
 (0)