Skip to content

Commit 087ff39

Browse files
author
Ivan Cheung
committed
Consolidated samples
1 parent 2fc3f90 commit 087ff39

6 files changed

+29
-110
lines changed

.sample_configs/param_handlers/create_training_pipeline_tabular_forecasting_sample.py

+26-36
Original file line numberDiff line numberDiff line change
@@ -13,52 +13,44 @@
1313
# limitations under the License.
1414
#
1515

16+
1617
def make_parent(parent: str) -> str:
1718
parent = parent
1819

1920
return parent
2021

22+
2123
def make_training_pipeline(
22-
display_name: str,
23-
dataset_id: str,
24-
model_display_name: str,
25-
target_column: str,
26-
time_series_identifier_column: str,
27-
time_column: str,
28-
static_columns: str,
29-
time_variant_past_only_columns: str,
30-
time_variant_past_and_future_columns: str,
31-
forecast_window_end: int,
32-
) -> google.cloud.aiplatform_v1alpha1.types.training_pipeline.TrainingPipeline:
24+
display_name: str,
25+
dataset_id: str,
26+
model_display_name: str,
27+
target_column: str,
28+
time_series_identifier_column: str,
29+
time_column: str,
30+
static_columns: str,
31+
time_variant_past_only_columns: str,
32+
time_variant_past_and_future_columns: str,
33+
forecast_window_end: int,
34+
) -> google.cloud.aiplatform_v1alpha1.types.training_pipeline.TrainingPipeline:
3335
# set the columns used for training and their data types
3436
transformations = [
3537
{"auto": {"column_name": "date"}},
3638
{"auto": {"column_name": "state_name"}},
3739
{"auto": {"column_name": "county_fips_code"}},
3840
{"auto": {"column_name": "confirmed_cases"}},
39-
{"auto": {"column_name": "deaths"}}
41+
{"auto": {"column_name": "deaths"}},
4042
]
4143

4244
period = {"unit": "day", "quantity": 1}
4345

46+
# the inputs should be formatted according to the training_task_definition yaml file
4447
training_task_inputs_dict = {
4548
# required inputs
4649
"targetColumn": target_column,
4750
"timeSeriesIdentifierColumn": time_series_identifier_column,
4851
"timeColumn": time_column,
4952
"transformations": transformations,
5053
"period": period,
51-
52-
# Objective function the model is to be optimized towards.
53-
# The training process creates a Model that optimizes the value of the objective
54-
# function over the validation set. The supported optimization objectives:
55-
# "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
56-
# "minimize-mae" - Minimize mean-absolute error (MAE).
57-
# "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
58-
# "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE).
59-
# "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE)
60-
# and mean-absolute-error (MAE).
61-
# "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles.
6254
"optimizationObjective": "minimize-rmse",
6355
"trainBudgetMilliNodeHours": 8000,
6456
"staticColumns": static_columns,
@@ -70,20 +62,18 @@ def make_training_pipeline(
7062
training_task_inputs = to_protobuf_value(training_task_inputs_dict)
7163

7264
training_pipeline = {
73-
'display_name': display_name,
74-
'training_task_definition': "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
75-
'training_task_inputs': training_task_inputs,
76-
'input_data_config': {
77-
'dataset_id': dataset_id,
78-
'fraction_split': {
79-
'training_fraction': 0.8,
80-
'validation_fraction': 0.1,
81-
'test_fraction': 0.1,
82-
}
65+
"display_name": display_name,
66+
"training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
67+
"training_task_inputs": training_task_inputs,
68+
"input_data_config": {
69+
"dataset_id": dataset_id,
70+
"fraction_split": {
71+
"training_fraction": 0.8,
72+
"validation_fraction": 0.1,
73+
"test_fraction": 0.1,
74+
},
8375
},
84-
'model_to_upload': {
85-
'display_name': model_display_name
86-
}
76+
"model_to_upload": {"display_name": model_display_name},
8777
}
8878

8979
return training_pipeline

.sample_configs/param_handlers/list_model_evaluations_tabular_forecasting_sample.py

-5
This file was deleted.

.sample_configs/process_configs.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ create_batch_prediction_job_custom_image_explain_sample: {}
1818
create_batch_prediction_job_custom_tabular_explain_sample: {}
1919
create_batch_prediction_job_sample: {}
2020
create_batch_prediction_job_tabular_explain_sample: {}
21+
create_batch_prediction_job_tabular_forecasting_sample: {}
2122
create_batch_prediction_job_text_classification_sample: {}
2223
create_batch_prediction_job_text_entity_extraction_sample: {}
2324
create_batch_prediction_job_text_sentiment_analysis_sample: {}

.sample_configs/variants.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ create_batch_prediction_job:
2222
- custom_image_explain
2323
- custom_tabular_explain
2424
- tabular_explain
25+
- tabular_forecasting
2526
- text_classification
2627
- text_entity_extraction
2728
- text_sentiment_analysis
@@ -132,7 +133,6 @@ get_model_evaluation:
132133
- image_classification
133134
- image_object_detection
134135
- tabular_classification
135-
- tabular_forecasting
136136
- tabular_regression
137137
- text_classification
138138
- text_entity_extraction
@@ -177,7 +177,7 @@ list_hyperparameter_tuning_jobs:
177177
list_model_evaluation_slices:
178178
- ''
179179
list_model_evaluations:
180-
- tabular_forecasting
180+
- ''
181181
list_models:
182182
- ''
183183
list_specialist_pools:

samples/snippets/get_model_evaluation_tabular_forecasting_sample.py

-37
This file was deleted.

samples/snippets/get_model_evaluation_tabular_forecasting_sample_test.py

-30
This file was deleted.

0 commit comments

Comments
 (0)