Skip to content

Commit a74d6c3

Browse files
author
Ivan Cheung
committed
Added forecasting test
1 parent a6ac96d commit a74d6c3

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from uuid import uuid4
17+
18+
from google.cloud import aiplatform
19+
import pytest
20+
21+
import cancel_training_pipeline_sample
22+
import create_training_pipeline_tabular_forecasting_sample
23+
import delete_training_pipeline_sample
24+
import helpers
25+
26+
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
27+
DATASET_ID = "7172228697192136704" # iris 1000
28+
DISPLAY_NAME = f"temp_create_training_pipeline_test_{uuid4()}"
29+
TARGET_COLUMN = "deaths"
30+
PREDICTION_TYPE = "forecasting"
31+
32+
33+
@pytest.fixture
34+
def shared_state():
35+
state = {}
36+
yield state
37+
38+
39+
@pytest.fixture(scope="function", autouse=True)
40+
def teardown(shared_state):
41+
yield
42+
43+
training_pipeline_id = shared_state["training_pipeline_name"].split("/")[-1]
44+
45+
# Stop the training pipeline
46+
cancel_training_pipeline_sample.cancel_training_pipeline_sample(
47+
project=PROJECT_ID, training_pipeline_id=training_pipeline_id
48+
)
49+
50+
client_options = {"api_endpoint": "us-central1-aiplatform.googleapis.com"}
51+
pipeline_client = aiplatform.gapic.PipelineServiceClient(
52+
client_options=client_options
53+
)
54+
55+
# Waiting for training pipeline to be in CANCELLED state
56+
helpers.wait_for_job_state(
57+
get_job_method=pipeline_client.get_training_pipeline,
58+
name=shared_state["training_pipeline_name"],
59+
)
60+
61+
# Delete the training pipeline
62+
delete_training_pipeline_sample.delete_training_pipeline_sample(
63+
project=PROJECT_ID, training_pipeline_id=training_pipeline_id
64+
)
65+
66+
67+
def test_ucaip_generated_create_training_pipeline_sample(capsys, shared_state):
68+
69+
create_training_pipeline_tabular_forecasting_sample.create_training_pipeline_tabular_forecasting_sample(
70+
project=PROJECT_ID,
71+
display_name=DISPLAY_NAME,
72+
dataset_id=DATASET_ID,
73+
model_display_name=f"Temp Model for {DISPLAY_NAME}",
74+
target_column=TARGET_COLUMN,
75+
time_series_identifier_column="county",
76+
time_column="date",
77+
static_columns=["state_name"],
78+
time_variant_past_only_columns=["deaths"],
79+
time_variant_past_and_future_columns=[],
80+
forecast_window_end=10
81+
)
82+
83+
out, _ = capsys.readouterr()
84+
assert "response:" in out
85+
86+
# Save resource name of the newly created training pipeline
87+
shared_state["training_pipeline_name"] = helpers.get_name(out)

0 commit comments

Comments
 (0)