Skip to content

Commit c19b6c3

Browse files
yinghsienwu authored and copybara-github committed
feat: Support experiment autologging when using persistent cluster as executor
PiperOrigin-RevId: 574306937
1 parent a9d7632 commit c19b6c3

File tree

7 files changed

+244
-27
lines changed

7 files changed

+244
-27
lines changed

tests/unit/vertexai/conftest.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
4545
PersistentResource,
4646
ResourcePool,
47+
ResourceRuntimeSpec,
48+
ServiceAccountSpec,
4749
)
4850

4951

@@ -54,6 +56,7 @@
5456
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
5557
_TEST_BUCKET_NAME = "gs://test_bucket"
5658
_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir"
59+
_TEST_SERVICE_ACCOUNT = f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com"
5760

5861
_TEST_INPUTS = [
5962
"--arg_0=string_val_0",
@@ -86,7 +89,9 @@
8689
labels={"trained_by_vertex_ai": "true"},
8790
)
8891

89-
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource()
92+
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource(
93+
resource_runtime_spec=ResourceRuntimeSpec(service_account_spec=ServiceAccountSpec())
94+
)
9095
resource_pool = ResourcePool()
9196
resource_pool.machine_spec.machine_type = "n1-standard-4"
9297
resource_pool.replica_count = 1
@@ -95,8 +100,15 @@
95100
_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool]
96101

97102

98-
_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource()
99-
_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING"
103+
_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource(state="RUNNING")
104+
_TEST_PERSISTENT_RESOURCE_SERVICE_ACCOUNT_RUNNING = PersistentResource(
105+
state="RUNNING",
106+
resource_runtime_spec=ResourceRuntimeSpec(
107+
service_account_spec=ServiceAccountSpec(
108+
enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT
109+
)
110+
),
111+
)
100112

101113

102114
@pytest.fixture(scope="module")
@@ -284,6 +296,18 @@ def persistent_resource_running_mock():
284296
yield persistent_resource_running_mock
285297

286298

299+
@pytest.fixture
300+
def persistent_resource_service_account_running_mock():
301+
with mock.patch.object(
302+
PersistentResourceServiceClient,
303+
"get_persistent_resource",
304+
) as persistent_resource_service_account_running_mock:
305+
persistent_resource_service_account_running_mock.return_value = (
306+
_TEST_PERSISTENT_RESOURCE_SERVICE_ACCOUNT_RUNNING
307+
)
308+
yield persistent_resource_service_account_running_mock
309+
310+
287311
@pytest.fixture
288312
def persistent_resource_exception_mock():
289313
with mock.patch.object(

tests/unit/vertexai/test_persistent_resource_util.py

+8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
)
3434
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3535
ResourcePool,
36+
ResourceRuntimeSpec,
37+
ServiceAccountSpec,
3638
)
3739
from vertexai.preview._workflow.executor import (
3840
persistent_resource_util,
@@ -75,8 +77,14 @@
7577
)
7678
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource(
7779
resource_pools=[resource_pool_0],
80+
resource_runtime_spec=ResourceRuntimeSpec(
81+
service_account_spec=ServiceAccountSpec(enable_custom_service_account=False),
82+
),
7883
)
7984
_TEST_REQUEST_RUNNING_CUSTOM = PersistentResource(
85+
resource_runtime_spec=ResourceRuntimeSpec(
86+
service_account_spec=ServiceAccountSpec(enable_custom_service_account=False),
87+
),
8088
resource_pools=[resource_pool_0, resource_pool_1],
8189
)
8290

tests/unit/vertexai/test_remote_training.py

+134-6
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979

8080
# vertexai constants
8181
_TEST_PROJECT = "test-project"
82-
_TEST_PROJECT_NUMBER = 123
82+
_TEST_PROJECT_NUMBER = 12345678
8383
_TEST_LOCATION = "us-central1"
8484
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
8585
_TEST_BUCKET_NAME = "gs://test-bucket"
@@ -88,6 +88,7 @@
8888
_TEST_REMOTE_JOB_BASE_PATH = os.path.join(_TEST_BUCKET_NAME, _TEST_REMOTE_JOB_NAME)
8989
_TEST_EXPERIMENT = "test-experiment"
9090
_TEST_EXPERIMENT_RUN = "test-experiment-run"
91+
_TEST_SERVICE_ACCOUNT = f"{_TEST_PROJECT_NUMBER}-compute@developer.gserviceaccount.com"
9192

9293
# dataset constants
9394
dataset = load_iris()
@@ -269,6 +270,20 @@
269270
],
270271
)
271272

273+
_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT = configs.PersistentResourceConfig(
274+
name=_TEST_PERSISTENT_RESOURCE_ID,
275+
resource_pools=[
276+
remote_specs.ResourcePool(
277+
replica_count=_TEST_REPLICA_COUNT,
278+
),
279+
remote_specs.ResourcePool(
280+
machine_type="n1-standard-8",
281+
replica_count=2,
282+
),
283+
],
284+
service_account=_TEST_SERVICE_ACCOUNT,
285+
)
286+
272287
_TEST_PERSISTENT_RESOURCE_CONFIG_DISABLE = configs.PersistentResourceConfig(
273288
name=_TEST_PERSISTENT_RESOURCE_ID,
274289
resource_pools=[
@@ -1583,7 +1598,7 @@ def test_remote_training_keras_distributed_no_cuda_no_worker_pool_specs(
15831598
@pytest.mark.xfail(
15841599
sys.version_info.minor >= 8,
15851600
raises=ValueError,
1586-
reason="Flaky in python 3.8, 3.10, 3.11",
1601+
reason="Flaky in python >=3.8",
15871602
)
15881603
@pytest.mark.usefixtures(
15891604
"list_default_tensorboard_mock",
@@ -1667,7 +1682,7 @@ def test_remote_training_sklearn_with_experiment(
16671682
@pytest.mark.xfail(
16681683
sys.version_info.minor >= 8,
16691684
raises=ValueError,
1670-
reason="Flaky in python 3.8, 3.10, 3.11",
1685+
reason="Flaky in python >=3.8",
16711686
)
16721687
@pytest.mark.usefixtures(
16731688
"list_default_tensorboard_mock",
@@ -1856,6 +1871,27 @@ def test_remote_training_sklearn_with_persistent_cluster(
18561871
model.score(_X_TEST, _Y_TEST)
18571872

18581873
@pytest.mark.usefixtures(
1874+
"mock_timestamped_unique_name",
1875+
"mock_get_custom_job",
1876+
"mock_autolog_disabled",
1877+
"persistent_resource_running_mock",
1878+
)
1879+
def test_initialize_existing_persistent_resource_service_account_mismatch(self):
1880+
vertexai.init(
1881+
project=_TEST_PROJECT,
1882+
location=_TEST_LOCATION,
1883+
staging_bucket=_TEST_BUCKET_NAME,
1884+
)
1885+
with pytest.raises(ValueError) as e:
1886+
vertexai.preview.init(
1887+
cluster=_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT
1888+
)
1889+
e.match(
1890+
regexp=r"Expect the existing cluster was created with the service account "
1891+
)
1892+
1893+
@pytest.mark.usefixtures(
1894+
"mock_get_project_number",
18591895
"list_default_tensorboard_mock",
18601896
"mock_get_experiment_run",
18611897
"mock_get_metadata_store",
@@ -1865,7 +1901,7 @@ def test_remote_training_sklearn_with_persistent_cluster(
18651901
"mock_autolog_enabled",
18661902
"persistent_resource_running_mock",
18671903
)
1868-
def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error(
1904+
def test_remote_training_sklearn_with_persistent_cluster_no_service_account_and_experiment_error(
18691905
self,
18701906
):
18711907
vertexai.init(
@@ -1884,9 +1920,101 @@ def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error(
18841920
with pytest.raises(ValueError) as e:
18851921
model.fit.vertex.remote_config.service_account = "GCE"
18861922
model.fit(_X_TRAIN, _Y_TRAIN)
1887-
e.match(
1888-
regexp=r"Persistent cluster currently does not support custom service account."
1923+
e.match(regexp=r"The service account for autologging")
1924+
1925+
# TODO(b/300116902) Remove this once we find better solution.
1926+
@pytest.mark.xfail(
1927+
sys.version_info.minor >= 8,
1928+
raises=ValueError,
1929+
reason="Flaky in python >=3.8",
1930+
)
1931+
@pytest.mark.usefixtures(
1932+
"mock_get_project_number",
1933+
"list_default_tensorboard_mock",
1934+
"mock_get_experiment_run",
1935+
"mock_get_metadata_store",
1936+
"get_artifact_not_found_mock",
1937+
"update_context_mock",
1938+
"aiplatform_autolog_mock",
1939+
"mock_autolog_enabled",
1940+
"persistent_resource_service_account_running_mock",
1941+
"mock_timestamped_unique_name",
1942+
"mock_get_custom_job",
1943+
)
1944+
def test_remote_training_sklearn_with_persistent_cluster_and_experiment_autologging(
1945+
self,
1946+
mock_any_serializer_sklearn,
1947+
mock_create_custom_job,
1948+
):
1949+
vertexai.init(
1950+
project=_TEST_PROJECT,
1951+
location=_TEST_LOCATION,
1952+
staging_bucket=_TEST_BUCKET_NAME,
1953+
experiment=_TEST_EXPERIMENT,
18891954
)
1955+
vertexai.preview.init(
1956+
remote=True,
1957+
autolog=True,
1958+
cluster=_TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT,
1959+
)
1960+
1961+
vertexai.preview.start_run(_TEST_EXPERIMENT_RUN, resume=True)
1962+
1963+
LogisticRegression = vertexai.preview.remote(_logistic.LogisticRegression)
1964+
model = LogisticRegression()
1965+
1966+
model.fit.vertex.remote_config.service_account = _TEST_SERVICE_ACCOUNT
1967+
1968+
model.fit(_X_TRAIN, _Y_TRAIN)
1969+
1970+
# check that model is serialized correctly
1971+
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
1972+
to_serialize=model,
1973+
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/input_estimator"),
1974+
)
1975+
1976+
# check that args are serialized correctly
1977+
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
1978+
to_serialize=_X_TRAIN,
1979+
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/X"),
1980+
)
1981+
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
1982+
to_serialize=_Y_TRAIN,
1983+
gcs_path=os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "input/y"),
1984+
)
1985+
1986+
# check that CustomJob is created correctly
1987+
expected_custom_job = _get_custom_job_proto(
1988+
service_account=_TEST_SERVICE_ACCOUNT,
1989+
experiment=_TEST_EXPERIMENT,
1990+
experiment_run=_TEST_EXPERIMENT_RUN,
1991+
autolog_enabled=True,
1992+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
1993+
)
1994+
mock_create_custom_job.assert_called_once_with(
1995+
parent=_TEST_PARENT,
1996+
custom_job=expected_custom_job,
1997+
timeout=None,
1998+
)
1999+
2000+
# check that trained model is deserialized correctly
2001+
mock_any_serializer_sklearn.return_value.deserialize.assert_has_calls(
2002+
[
2003+
mock.call(
2004+
os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_estimator")
2005+
),
2006+
mock.call(
2007+
os.path.join(_TEST_REMOTE_JOB_BASE_PATH, "output/output_data")
2008+
),
2009+
]
2010+
)
2011+
2012+
# change to `vertexai.preview.init(remote=False)` to use local prediction
2013+
vertexai.preview.init(remote=False)
2014+
2015+
# check that local model is updated in place
2016+
# `model.score` raises NotFittedError if the model is not updated
2017+
model.score(_X_TEST, _Y_TEST)
18902018

18912019
@pytest.mark.usefixtures(
18922020
"mock_timestamped_unique_name",

vertexai/preview/_workflow/executor/persistent_resource_util.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3333
ResourcePool,
34+
ResourceRuntimeSpec,
35+
ServiceAccountSpec,
3436
)
3537
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
3638
GetPersistentResourceRequest,
@@ -61,18 +63,28 @@ def _create_persistent_resource_client(location: Optional[str] = "us-central1"):
6163
)
6264

6365

64-
def check_persistent_resource(cluster_resource_name: str) -> bool:
66+
def cluster_resource_name(project: str, location: str, name: str) -> str:
67+
"""Helper method to get persistent resource name."""
68+
client = _create_persistent_resource_client(location)
69+
return client.persistent_resource_path(project, location, name)
70+
71+
72+
def check_persistent_resource(
73+
cluster_resource_name: str, service_account: Optional[str] = None
74+
) -> bool:
6575
"""Helper method to check if a persistent resource exists or not.
6676
6777
Args:
6878
cluster_resource_name: Persistent Resource name. Has the form:
6979
``projects/my-project/locations/my-region/persistentResources/cluster-name``.
80+
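service_account: Optional. Service account the existing cluster is expected to use; when provided, it is compared against the service account configured on the cluster.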
7081
7182
Returns:
7283
True if a Persistent Resource exists.
7384
7485
Raises:
7586
ValueError: if existing cluster is not RUNNING.
87+
ValueError: if a service account is specified but does not match the one on the existing cluster.
7688
"""
7789
# Parse resource name to get the location.
7890
locataion = cluster_resource_name.split("/")[3]
@@ -91,6 +103,24 @@ def check_persistent_resource(cluster_resource_name: str) -> bool:
91103
cluster_resource_name,
92104
"` isn't running, please specify a different cluster_name.",
93105
)
106+
# Check if service account of this existing persistent resource matches initialized one.
107+
existing_cluster_service_account = (
108+
response.resource_runtime_spec.service_account_spec.service_account
109+
if response.resource_runtime_spec.service_account_spec
110+
else None
111+
)
112+
113+
if (
114+
service_account is not None
115+
and existing_cluster_service_account != service_account
116+
):
117+
raise ValueError(
118+
"Expect the existing cluster was created with the service account `",
119+
service_account,
120+
"`, but got `",
121+
existing_cluster_service_account,
122+
"` , please ensure service account is consistent with the initialization.",
123+
)
94124
return True
95125

96126

@@ -185,6 +215,7 @@ def _get_persistent_resource(cluster_resource_name: str):
185215
def create_persistent_resource(
186216
cluster_resource_name: str,
187217
resource_pools: Optional[List[remote_specs.ResourcePool]] = None,
218+
service_account: Optional[str] = None,
188219
):
189220
"""Create a persistent resource."""
190221
locataion = cluster_resource_name.split("/")[3]
@@ -209,6 +240,15 @@ def create_persistent_resource(
209240

210241
persistent_resource = PersistentResource(resource_pools=pools)
211242

243+
enable_custom_service_account = True if service_account is not None else False
244+
245+
resource_runtime_spec = ResourceRuntimeSpec(
246+
service_account_spec=ServiceAccountSpec(
247+
enable_custom_service_account=enable_custom_service_account,
248+
service_account=service_account,
249+
),
250+
)
251+
persistent_resource.resource_runtime_spec = resource_runtime_spec
212252
request = persistent_resource_service.CreatePersistentResourceRequest(
213253
parent=parent,
214254
persistent_resource=persistent_resource,

0 commit comments

Comments
 (0)