89 | 89 | _TEST_WINDOW_STRIDE_LENGTH = 1
90 | 90 | _TEST_WINDOW_MAX_COUNT = None
91 | 91 | _TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"]
| 92 | +_TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = [
| 93 | +    "exp1",
| 94 | +    "exp2",
| 95 | +    "enable_probabilistic_inference",
| 96 | +]
92 | 97 | _TEST_TRAINING_TASK_INPUTS_DICT = {
93 | 98 |     # required inputs
94 | 99 |     "targetColumn": _TEST_TRAINING_TARGET_COLUMN,

134 | 139 |     struct_pb2.Value(),
135 | 140 | )
136 | 141 |
| 142 | +_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = (
| 143 | +    json_format.ParseDict(
| 144 | +        {
| 145 | +            **_TEST_TRAINING_TASK_INPUTS_DICT,
| 146 | +            "additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS,
| 147 | +            "enableProbabilisticInference": True,
| 148 | +        },
| 149 | +        struct_pb2.Value(),
| 150 | +    )
| 151 | +)
| 152 | +
137 | 153 | _TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
138 | 154 |     _TEST_TRAINING_TASK_INPUTS_DICT,
139 | 155 |     struct_pb2.Value(),
@@ -1243,3 +1259,92 @@ def test_splits_default(
1243 | 1259 |             training_pipeline=true_training_pipeline,
1244 | 1260 |             timeout=None,
1245 | 1261 |         )
| 1262 | +
| 1263 | +    @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
| 1264 | +    @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
| 1265 | +    @pytest.mark.usefixtures("mock_pipeline_service_get")
| 1266 | +    @pytest.mark.parametrize("sync", [True, False])
| 1267 | +    @pytest.mark.parametrize(
| 1268 | +        "training_job",
| 1269 | +        [
| 1270 | +            training_jobs.AutoMLForecastingTrainingJob,
| 1271 | +            training_jobs.SequenceToSequencePlusForecastingTrainingJob,
| 1272 | +        ],
| 1273 | +    )
| 1274 | +    def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference(
| 1275 | +        self,
| 1276 | +        mock_pipeline_service_create,
| 1277 | +        mock_dataset_time_series,
| 1278 | +        mock_model_service_get,
| 1279 | +        sync,
| 1280 | +        training_job,
| 1281 | +    ):
| 1282 | +        aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)
| 1283 | +
| 1284 | +        job = training_job(
| 1285 | +            display_name=_TEST_DISPLAY_NAME,
| 1286 | +            optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
| 1287 | +            column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
| 1288 | +        )
| 1289 | +
| 1290 | +        job._add_additional_experiments(
| 1291 | +            _TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE
| 1292 | +        )
| 1293 | +
| 1294 | +        model_from_job = job.run(
| 1295 | +            dataset=mock_dataset_time_series,
| 1296 | +            target_column=_TEST_TRAINING_TARGET_COLUMN,
| 1297 | +            time_column=_TEST_TRAINING_TIME_COLUMN,
| 1298 | +            time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
| 1299 | +            unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
| 1300 | +            available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
| 1301 | +            forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
| 1302 | +            data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
| 1303 | +            data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
| 1304 | +            weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
| 1305 | +            time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
| 1306 | +            context_window=_TEST_TRAINING_CONTEXT_WINDOW,
| 1307 | +            budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
| 1308 | +            export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
| 1309 | +            export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
| 1310 | +            export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
| 1311 | +            quantiles=_TEST_TRAINING_QUANTILES,
| 1312 | +            validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
| 1313 | +            hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS,
| 1314 | +            hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT,
| 1315 | +            hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT,
| 1316 | +            hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT,
| 1317 | +            window_column=_TEST_WINDOW_COLUMN,
| 1318 | +            window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
| 1319 | +            window_max_count=_TEST_WINDOW_MAX_COUNT,
| 1320 | +            sync=sync,
| 1321 | +            create_request_timeout=None,
| 1322 | +            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
| 1323 | +        )
| 1324 | +
| 1325 | +        if not sync:
| 1326 | +            model_from_job.wait()
| 1327 | +
| 1328 | +        # Test that it defaults to the job display name
| 1329 | +        true_managed_model = gca_model.Model(
| 1330 | +            display_name=_TEST_DISPLAY_NAME,
| 1331 | +            version_aliases=["default"],
| 1332 | +        )
| 1333 | +
| 1334 | +        true_input_data_config = gca_training_pipeline.InputDataConfig(
| 1335 | +            dataset_id=mock_dataset_time_series.name,
| 1336 | +        )
| 1337 | +
| 1338 | +        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
| 1339 | +            display_name=_TEST_DISPLAY_NAME,
| 1340 | +            training_task_definition=training_job._training_task_definition,
| 1341 | +            training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE,
| 1342 | +            model_to_upload=true_managed_model,
| 1343 | +            input_data_config=true_input_data_config,
| 1344 | +        )
| 1345 | +
| 1346 | +        mock_pipeline_service_create.assert_called_once_with(
| 1347 | +            parent=initializer.global_config.common_location_path(),
| 1348 | +            training_pipeline=true_training_pipeline,
| 1349 | +            timeout=None,
| 1350 | +        )