79
79
80
80
# vertexai constants
81
81
_TEST_PROJECT = "test-project"
82
- _TEST_PROJECT_NUMBER = 123
82
+ _TEST_PROJECT_NUMBER = 12345678
83
83
_TEST_LOCATION = "us-central1"
84
84
_TEST_PARENT = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } "
85
85
_TEST_BUCKET_NAME = "gs://test-bucket"
88
88
_TEST_REMOTE_JOB_BASE_PATH = os .path .join (_TEST_BUCKET_NAME , _TEST_REMOTE_JOB_NAME )
89
89
_TEST_EXPERIMENT = "test-experiment"
90
90
_TEST_EXPERIMENT_RUN = "test-experiment-run"
91
+ _TEST_SERVICE_ACCOUNT = f"{ _TEST_PROJECT_NUMBER } -compute@developer.gserviceaccount.com "
91
92
92
93
# dataset constants
93
94
dataset = load_iris ()
269
270
],
270
271
)
271
272
273
+ _TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT = configs .PersistentResourceConfig (
274
+ name = _TEST_PERSISTENT_RESOURCE_ID ,
275
+ resource_pools = [
276
+ remote_specs .ResourcePool (
277
+ replica_count = _TEST_REPLICA_COUNT ,
278
+ ),
279
+ remote_specs .ResourcePool (
280
+ machine_type = "n1-standard-8" ,
281
+ replica_count = 2 ,
282
+ ),
283
+ ],
284
+ service_account = _TEST_SERVICE_ACCOUNT ,
285
+ )
286
+
272
287
_TEST_PERSISTENT_RESOURCE_CONFIG_DISABLE = configs .PersistentResourceConfig (
273
288
name = _TEST_PERSISTENT_RESOURCE_ID ,
274
289
resource_pools = [
@@ -1583,7 +1598,7 @@ def test_remote_training_keras_distributed_no_cuda_no_worker_pool_specs(
1583
1598
@pytest .mark .xfail (
1584
1599
sys .version_info .minor >= 8 ,
1585
1600
raises = ValueError ,
1586
- reason = "Flaky in python 3.8, 3.10, 3.11 " ,
1601
+ reason = "Flaky in python >= 3.8" ,
1587
1602
)
1588
1603
@pytest .mark .usefixtures (
1589
1604
"list_default_tensorboard_mock" ,
@@ -1667,7 +1682,7 @@ def test_remote_training_sklearn_with_experiment(
1667
1682
@pytest .mark .xfail (
1668
1683
sys .version_info .minor >= 8 ,
1669
1684
raises = ValueError ,
1670
- reason = "Flaky in python 3.8, 3.10, 3.11 " ,
1685
+ reason = "Flaky in python >= 3.8" ,
1671
1686
)
1672
1687
@pytest .mark .usefixtures (
1673
1688
"list_default_tensorboard_mock" ,
@@ -1856,6 +1871,27 @@ def test_remote_training_sklearn_with_persistent_cluster(
1856
1871
model .score (_X_TEST , _Y_TEST )
1857
1872
1858
1873
@pytest .mark .usefixtures (
1874
+ "mock_timestamped_unique_name" ,
1875
+ "mock_get_custom_job" ,
1876
+ "mock_autolog_disabled" ,
1877
+ "persistent_resource_running_mock" ,
1878
+ )
1879
+ def test_initialize_existing_persistent_resource_service_account_mismatch (self ):
1880
+ vertexai .init (
1881
+ project = _TEST_PROJECT ,
1882
+ location = _TEST_LOCATION ,
1883
+ staging_bucket = _TEST_BUCKET_NAME ,
1884
+ )
1885
+ with pytest .raises (ValueError ) as e :
1886
+ vertexai .preview .init (
1887
+ cluster = _TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT
1888
+ )
1889
+ e .match (
1890
+ regexp = r"Expect the existing cluster was created with the service account "
1891
+ )
1892
+
1893
+ @pytest .mark .usefixtures (
1894
+ "mock_get_project_number" ,
1859
1895
"list_default_tensorboard_mock" ,
1860
1896
"mock_get_experiment_run" ,
1861
1897
"mock_get_metadata_store" ,
@@ -1865,7 +1901,7 @@ def test_remote_training_sklearn_with_persistent_cluster(
1865
1901
"mock_autolog_enabled" ,
1866
1902
"persistent_resource_running_mock" ,
1867
1903
)
1868
- def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error (
1904
+ def test_remote_training_sklearn_with_persistent_cluster_no_service_account_and_experiment_error (
1869
1905
self ,
1870
1906
):
1871
1907
vertexai .init (
@@ -1884,9 +1920,101 @@ def test_remote_training_sklearn_with_persistent_cluster_and_experiment_error(
1884
1920
with pytest .raises (ValueError ) as e :
1885
1921
model .fit .vertex .remote_config .service_account = "GCE"
1886
1922
model .fit (_X_TRAIN , _Y_TRAIN )
1887
- e .match (
1888
- regexp = r"Persistent cluster currently does not support custom service account."
1923
+ e .match (regexp = r"The service account for autologging" )
1924
+
1925
+ # TODO(b/300116902) Remove this once we find a better solution.
1926
+ @pytest .mark .xfail (
1927
+ sys .version_info .minor >= 8 ,
1928
+ raises = ValueError ,
1929
+ reason = "Flaky in python >=3.8" ,
1930
+ )
1931
+ @pytest .mark .usefixtures (
1932
+ "mock_get_project_number" ,
1933
+ "list_default_tensorboard_mock" ,
1934
+ "mock_get_experiment_run" ,
1935
+ "mock_get_metadata_store" ,
1936
+ "get_artifact_not_found_mock" ,
1937
+ "update_context_mock" ,
1938
+ "aiplatform_autolog_mock" ,
1939
+ "mock_autolog_enabled" ,
1940
+ "persistent_resource_service_account_running_mock" ,
1941
+ "mock_timestamped_unique_name" ,
1942
+ "mock_get_custom_job" ,
1943
+ )
1944
+ def test_remote_training_sklearn_with_persistent_cluster_and_experiment_autologging (
1945
+ self ,
1946
+ mock_any_serializer_sklearn ,
1947
+ mock_create_custom_job ,
1948
+ ):
1949
+ vertexai .init (
1950
+ project = _TEST_PROJECT ,
1951
+ location = _TEST_LOCATION ,
1952
+ staging_bucket = _TEST_BUCKET_NAME ,
1953
+ experiment = _TEST_EXPERIMENT ,
1889
1954
)
1955
+ vertexai .preview .init (
1956
+ remote = True ,
1957
+ autolog = True ,
1958
+ cluster = _TEST_PERSISTENT_RESOURCE_CONFIG_SERVICE_ACCOUNT ,
1959
+ )
1960
+
1961
+ vertexai .preview .start_run (_TEST_EXPERIMENT_RUN , resume = True )
1962
+
1963
+ LogisticRegression = vertexai .preview .remote (_logistic .LogisticRegression )
1964
+ model = LogisticRegression ()
1965
+
1966
+ model .fit .vertex .remote_config .service_account = _TEST_SERVICE_ACCOUNT
1967
+
1968
+ model .fit (_X_TRAIN , _Y_TRAIN )
1969
+
1970
+ # check that model is serialized correctly
1971
+ mock_any_serializer_sklearn .return_value .serialize .assert_any_call (
1972
+ to_serialize = model ,
1973
+ gcs_path = os .path .join (_TEST_REMOTE_JOB_BASE_PATH , "input/input_estimator" ),
1974
+ )
1975
+
1976
+ # check that args are serialized correctly
1977
+ mock_any_serializer_sklearn .return_value .serialize .assert_any_call (
1978
+ to_serialize = _X_TRAIN ,
1979
+ gcs_path = os .path .join (_TEST_REMOTE_JOB_BASE_PATH , "input/X" ),
1980
+ )
1981
+ mock_any_serializer_sklearn .return_value .serialize .assert_any_call (
1982
+ to_serialize = _Y_TRAIN ,
1983
+ gcs_path = os .path .join (_TEST_REMOTE_JOB_BASE_PATH , "input/y" ),
1984
+ )
1985
+
1986
+ # check that CustomJob is created correctly
1987
+ expected_custom_job = _get_custom_job_proto (
1988
+ service_account = _TEST_SERVICE_ACCOUNT ,
1989
+ experiment = _TEST_EXPERIMENT ,
1990
+ experiment_run = _TEST_EXPERIMENT_RUN ,
1991
+ autolog_enabled = True ,
1992
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
1993
+ )
1994
+ mock_create_custom_job .assert_called_once_with (
1995
+ parent = _TEST_PARENT ,
1996
+ custom_job = expected_custom_job ,
1997
+ timeout = None ,
1998
+ )
1999
+
2000
+ # check that trained model is deserialized correctly
2001
+ mock_any_serializer_sklearn .return_value .deserialize .assert_has_calls (
2002
+ [
2003
+ mock .call (
2004
+ os .path .join (_TEST_REMOTE_JOB_BASE_PATH , "output/output_estimator" )
2005
+ ),
2006
+ mock .call (
2007
+ os .path .join (_TEST_REMOTE_JOB_BASE_PATH , "output/output_data" )
2008
+ ),
2009
+ ]
2010
+ )
2011
+
2012
+ # change to `vertexai.preview.init(remote=False)` to use local prediction
2013
+ vertexai .preview .init (remote = False )
2014
+
2015
+ # check that local model is updated in place
2016
+ # `model.score` raises NotFittedError if the model is not updated
2017
+ model .score (_X_TEST , _Y_TEST )
1890
2018
1891
2019
@pytest .mark .usefixtures (
1892
2020
"mock_timestamped_unique_name" ,
0 commit comments