Skip to content

Commit a6465cc

Browse files
chore: Convert enable probabilistic inference from additional experiments (googleapis#1643)
* chore: Convert enable probabilistic inference from additional experiments. Converts `enable_probabilistic_inference` flag in additional experiments to a boolean field in the API `enableProbabilisticInference`, only adds if True. The flag is removed from the additional experiments to reduce duplication. * chore: Fix linting issues. Fixes extra newline. Co-authored-by: sasha-gitg <[email protected]>
1 parent 2cf9fe6 commit a6465cc

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

google/cloud/aiplatform/training_jobs.py

+17
Original file line numberDiff line numberDiff line change
@@ -2417,6 +2417,9 @@ def _run(
24172417
max_count=window_max_count,
24182418
)
24192419

2420+
# TODO(b/244643824): Replace additional experiments with a new job arg.
2421+
enable_probabilistic_inference = self._convert_enable_probabilistic_inference()
2422+
24202423
training_task_inputs_dict = {
24212424
# required inputs
24222425
"targetColumn": target_column,
@@ -2459,6 +2462,11 @@ def _run(
24592462
if window_config:
24602463
training_task_inputs_dict["windowConfig"] = window_config
24612464

2465+
if enable_probabilistic_inference:
2466+
training_task_inputs_dict[
2467+
"enableProbabilisticInference"
2468+
] = enable_probabilistic_inference
2469+
24622470
final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
24632471
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
24642472
"bq://"
@@ -2541,6 +2549,15 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
25412549
"""
25422550
self._additional_experiments.extend(additional_experiments)
25432551

2552+
def _convert_enable_probabilistic_inference(self) -> bool:
2553+
"""Convert enable probabilistic from additional experiments."""
2554+
key = "enable_probabilistic_inference"
2555+
if self._additional_experiments:
2556+
if key in self._additional_experiments:
2557+
self._additional_experiments.remove(key)
2558+
return True
2559+
return False
2560+
25442561
@staticmethod
25452562
def _create_window_config(
25462563
column: Optional[str] = None,

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

+105
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@
8989
_TEST_WINDOW_STRIDE_LENGTH = 1
9090
_TEST_WINDOW_MAX_COUNT = None
9191
_TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"]
92+
_TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = [
93+
"exp1",
94+
"exp2",
95+
"enable_probabilistic_inference",
96+
]
9297
_TEST_TRAINING_TASK_INPUTS_DICT = {
9398
# required inputs
9499
"targetColumn": _TEST_TRAINING_TARGET_COLUMN,
@@ -134,6 +139,17 @@
134139
struct_pb2.Value(),
135140
)
136141

142+
# Expected trainingTaskInputs when the probabilistic-inference flag arrives via
# additional experiments: the flag becomes the dedicated
# `enableProbabilisticInference` boolean, while the remaining experiments are
# presumably passed through as `additionalExperiments` (the flag-free
# _TEST_ADDITIONAL_EXPERIMENTS list is defined earlier in this file).
_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = (
    json_format.ParseDict(
        dict(
            _TEST_TRAINING_TASK_INPUTS_DICT,
            additionalExperiments=_TEST_ADDITIONAL_EXPERIMENTS,
            enableProbabilisticInference=True,
        ),
        struct_pb2.Value(),
    )
)
137153
_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
138154
_TEST_TRAINING_TASK_INPUTS_DICT,
139155
struct_pb2.Value(),
@@ -1243,3 +1259,92 @@ def test_splits_default(
12431259
training_pipeline=true_training_pipeline,
12441260
timeout=None,
12451261
)
1262+
1263+
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.usefixtures("mock_pipeline_service_get")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize(
    "training_job",
    [
        training_jobs.AutoMLForecastingTrainingJob,
        training_jobs.SequenceToSequencePlusForecastingTrainingJob,
    ],
)
def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference(
    self,
    mock_pipeline_service_create,
    mock_dataset_time_series,
    mock_model_service_get,
    sync,
    training_job,
):
    """Checks conversion of the probabilistic-inference additional experiment.

    The job is given additional experiments that include
    ``enable_probabilistic_inference``; the pipeline-create request is then
    expected to match
    ``_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE``,
    i.e. the flag is carried as the boolean ``enableProbabilisticInference``
    training-task input rather than inside ``additionalExperiments``.
    Parametrized over both forecasting job classes and sync/async execution.
    """
    aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)

    job = training_job(
        display_name=_TEST_DISPLAY_NAME,
        optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
        column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
    )

    # The experiments list includes the probabilistic-inference flag alongside
    # ordinary experiment names.
    job._add_additional_experiments(
        _TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE
    )

    model_from_job = job.run(
        dataset=mock_dataset_time_series,
        target_column=_TEST_TRAINING_TARGET_COLUMN,
        time_column=_TEST_TRAINING_TIME_COLUMN,
        time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
        unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
        available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
        forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
        data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
        data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
        weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
        time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
        context_window=_TEST_TRAINING_CONTEXT_WINDOW,
        budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
        export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
        export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
        export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
        quantiles=_TEST_TRAINING_QUANTILES,
        validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
        hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS,
        hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT,
        hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT,
        hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT,
        window_column=_TEST_WINDOW_COLUMN,
        window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
        window_max_count=_TEST_WINDOW_MAX_COUNT,
        sync=sync,
        create_request_timeout=None,
        holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
    )

    # With sync=False, run() returns immediately; block until completion so
    # the create call below has happened.
    if not sync:
        model_from_job.wait()

    # Test that it defaults to the job display name
    true_managed_model = gca_model.Model(
        display_name=_TEST_DISPLAY_NAME,
        version_aliases=["default"],
    )

    true_input_data_config = gca_training_pipeline.InputDataConfig(
        dataset_id=mock_dataset_time_series.name,
    )

    # Expected pipeline: training_task_inputs carries
    # enableProbabilisticInference=True (see the constant's definition).
    true_training_pipeline = gca_training_pipeline.TrainingPipeline(
        display_name=_TEST_DISPLAY_NAME,
        training_task_definition=training_job._training_task_definition,
        training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE,
        model_to_upload=true_managed_model,
        input_data_config=true_input_data_config,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=initializer.global_config.common_location_path(),
        training_pipeline=true_training_pipeline,
        timeout=None,
    )

0 commit comments

Comments
 (0)