Skip to content

Commit 2401a1d

Browse files
jaycee-licopybara-github
authored andcommitted
fix: list method for MLMD schema classes
PiperOrigin-RevId: 506383196
1 parent 076308f commit 2401a1d

File tree

5 files changed

+241
-27
lines changed

5 files changed

+241
-27
lines changed

google/cloud/aiplatform/metadata/experiment_run_resource.py

+3-23
Original file line numberDiff line numberDiff line change
@@ -1418,34 +1418,14 @@ def get_experiment_models(self) -> List[google_artifact_schema.ExperimentModel]:
14181418
Returns:
14191419
List of ExperimentModel instances associated this run.
14201420
"""
1421-
# TODO(b/264194064) Replace this by ExperimentModel.list
1422-
artifact_list = artifact.Artifact.list(
1423-
filter=metadata_utils._make_filter_string(
1424-
in_context=[self.resource_name],
1425-
schema_title=google_artifact_schema.ExperimentModel.schema_title,
1426-
),
1421+
experiment_model_list = google_artifact_schema.ExperimentModel.list(
1422+
filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
14271423
project=self.project,
14281424
location=self.location,
14291425
credentials=self.credentials,
14301426
)
14311427

1432-
res = []
1433-
for model_artifact in artifact_list:
1434-
experiment_model = google_artifact_schema.ExperimentModel(
1435-
framework_name="",
1436-
framework_version="",
1437-
model_file="",
1438-
uri="",
1439-
)
1440-
experiment_model._gca_resource = model_artifact._gca_resource
1441-
experiment_model.project = model_artifact.project
1442-
experiment_model.location = model_artifact.location
1443-
experiment_model.credentials = model_artifact.credentials
1444-
experiment_model.api_client = model_artifact.api_client
1445-
1446-
res.append(experiment_model)
1447-
1448-
return res
1428+
return experiment_model_list
14491429

14501430
@_v1_not_supported
14511431
def associate_execution(self, execution: execution.Execution):

google/cloud/aiplatform/metadata/schema/base_artifact.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
1717

1818
import abc
1919

20-
from typing import Any, Optional, Dict
20+
from typing import Any, Optional, Dict, List
2121

2222
from google.auth import credentials as auth_credentials
2323
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
@@ -202,6 +202,63 @@ def create(
202202
self._init_with_resource_name(artifact_name=new_artifact_instance.resource_name)
203203
return self
204204

205+
@classmethod
206+
def list(
207+
cls,
208+
filter: Optional[str] = None, # pylint: disable=redefined-builtin
209+
metadata_store_id: str = "default",
210+
project: Optional[str] = None,
211+
location: Optional[str] = None,
212+
credentials: Optional[auth_credentials.Credentials] = None,
213+
order_by: Optional[str] = None,
214+
) -> List["BaseArtifactSchema"]:
215+
"""List all the Artifact resources with a particular schema.
216+
217+
Args:
218+
filter (str):
219+
Optional. A query to filter available resources for
220+
matching results.
221+
metadata_store_id (str):
222+
The <metadata_store_id> portion of the resource name with
223+
the format:
224+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
225+
If not provided, the MetadataStore's ID will be set to "default".
226+
project (str):
227+
Project used to create this resource. Overrides project set in
228+
aiplatform.init.
229+
location (str):
230+
Location used to create this resource. Overrides location set in
231+
aiplatform.init.
232+
credentials (auth_credentials.Credentials):
233+
Custom credentials used to create this resource. Overrides
234+
credentials set in aiplatform.init.
235+
order_by (str):
236+
Optional. How the list of messages is ordered.
237+
Specify the values to order by and an ordering operation. The
238+
default sorting order is ascending. To specify descending order
239+
for a field, users append a " desc" suffix; for example: "foo
240+
desc, bar". Subfields are specified with a ``.`` character, such
241+
as foo.bar. see https://google.aip.dev/132#ordering for more
242+
details.
243+
244+
Returns:
245+
A list of artifact resources with a particular schema.
246+
247+
"""
248+
schema_filter = f'schema_title="{cls.schema_title}"'
249+
if filter:
250+
filter = f"{filter} AND {schema_filter}"
251+
else:
252+
filter = schema_filter
253+
254+
return super().list(
255+
filter=filter,
256+
metadata_store_id=metadata_store_id,
257+
project=project,
258+
location=location,
259+
credentials=credentials,
260+
)
261+
205262
def sync_resource(self):
206263
"""Syncs local resource with the resource in metadata store.
207264

google/cloud/aiplatform/metadata/schema/base_context.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -160,6 +160,63 @@ def create(
160160
self._init_with_resource_name(context_name=new_context.resource_name)
161161
return self
162162

163+
@classmethod
164+
def list(
165+
cls,
166+
filter: Optional[str] = None, # pylint: disable=redefined-builtin
167+
metadata_store_id: str = "default",
168+
project: Optional[str] = None,
169+
location: Optional[str] = None,
170+
credentials: Optional[auth_credentials.Credentials] = None,
171+
order_by: Optional[str] = None,
172+
) -> List["BaseContextSchema"]:
173+
"""List all the Context resources with a particular schema.
174+
175+
Args:
176+
filter (str):
177+
Optional. A query to filter available resources for
178+
matching results.
179+
metadata_store_id (str):
180+
The <metadata_store_id> portion of the resource name with
181+
the format:
182+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
183+
If not provided, the MetadataStore's ID will be set to "default".
184+
project (str):
185+
Project used to create this resource. Overrides project set in
186+
aiplatform.init.
187+
location (str):
188+
Location used to create this resource. Overrides location set in
189+
aiplatform.init.
190+
credentials (auth_credentials.Credentials):
191+
Custom credentials used to create this resource. Overrides
192+
credentials set in aiplatform.init.
193+
order_by (str):
194+
Optional. How the list of messages is ordered.
195+
Specify the values to order by and an ordering operation. The
196+
default sorting order is ascending. To specify descending order
197+
for a field, users append a " desc" suffix; for example: "foo
198+
desc, bar". Subfields are specified with a ``.`` character, such
199+
as foo.bar. see https://google.aip.dev/132#ordering for more
200+
details.
201+
202+
Returns:
203+
A list of context resources with a particular schema.
204+
205+
"""
206+
schema_filter = f'schema_title="{cls.schema_title}"'
207+
if filter:
208+
filter = f"{filter} AND {schema_filter}"
209+
else:
210+
filter = schema_filter
211+
212+
return super().list(
213+
filter=filter,
214+
metadata_store_id=metadata_store_id,
215+
project=project,
216+
location=location,
217+
credentials=credentials,
218+
)
219+
163220
def add_artifacts_and_executions(
164221
self,
165222
artifact_resource_names: Optional[Sequence[str]] = None,

google/cloud/aiplatform/metadata/schema/base_execution.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -170,6 +170,63 @@ def create(
170170
)
171171
return self
172172

173+
@classmethod
174+
def list(
175+
cls,
176+
filter: Optional[str] = None, # pylint: disable=redefined-builtin
177+
metadata_store_id: str = "default",
178+
project: Optional[str] = None,
179+
location: Optional[str] = None,
180+
credentials: Optional[auth_credentials.Credentials] = None,
181+
order_by: Optional[str] = None,
182+
) -> List["BaseExecutionSchema"]:
183+
"""List all the Execution resources with a particular schema.
184+
185+
Args:
186+
filter (str):
187+
Optional. A query to filter available resources for
188+
matching results.
189+
metadata_store_id (str):
190+
The <metadata_store_id> portion of the resource name with
191+
the format:
192+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
193+
If not provided, the MetadataStore's ID will be set to "default".
194+
project (str):
195+
Project used to create this resource. Overrides project set in
196+
aiplatform.init.
197+
location (str):
198+
Location used to create this resource. Overrides location set in
199+
aiplatform.init.
200+
credentials (auth_credentials.Credentials):
201+
Custom credentials used to create this resource. Overrides
202+
credentials set in aiplatform.init.
203+
order_by (str):
204+
Optional. How the list of messages is ordered.
205+
Specify the values to order by and an ordering operation. The
206+
default sorting order is ascending. To specify descending order
207+
for a field, users append a " desc" suffix; for example: "foo
208+
desc, bar". Subfields are specified with a ``.`` character, such
209+
as foo.bar. see https://google.aip.dev/132#ordering for more
210+
details.
211+
212+
Returns:
213+
A list of execution resources with a particular schema.
214+
215+
"""
216+
schema_filter = f'schema_title="{cls.schema_title}"'
217+
if filter:
218+
filter = f"{filter} AND {schema_filter}"
219+
else:
220+
filter = schema_filter
221+
222+
return super().list(
223+
filter=filter,
224+
metadata_store_id=metadata_store_id,
225+
project=project,
226+
location=location,
227+
credentials=credentials,
228+
)
229+
173230
def start_execution(
174231
self,
175232
*,

tests/unit/aiplatform/test_metadata_schema.py

+63
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,27 @@ def create_context_mock():
206206
yield create_context_mock
207207

208208

209+
@pytest.fixture
210+
def list_artifacts_mock():
211+
with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock:
212+
list_artifacts_mock.return_value = []
213+
yield list_artifacts_mock
214+
215+
216+
@pytest.fixture
217+
def list_executions_mock():
218+
with patch.object(MetadataServiceClient, "list_executions") as list_executions_mock:
219+
list_executions_mock.return_value = []
220+
yield list_executions_mock
221+
222+
223+
@pytest.fixture
224+
def list_contexts_mock():
225+
with patch.object(MetadataServiceClient, "list_contexts") as list_contexts_mock:
226+
list_contexts_mock.return_value = []
227+
yield list_contexts_mock
228+
229+
209230
@pytest.mark.usefixtures("google_auth_mock")
210231
class TestMetadataBaseArtifactSchema:
211232
def setup_method(self):
@@ -369,6 +390,20 @@ class TestArtifact(base_artifact.BaseArtifactSchema):
369390
"sdk_command/aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"
370391
]
371392

393+
def test_list_artifacts(self, list_artifacts_mock):
394+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
395+
396+
class TestArtifact(base_artifact.BaseArtifactSchema):
397+
schema_title = _TEST_SCHEMA_TITLE
398+
399+
TestArtifact.list()
400+
list_artifacts_mock.assert_called_once_with(
401+
request={
402+
"parent": f"{_TEST_PARENT}/metadataStores/default",
403+
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
404+
}
405+
)
406+
372407

373408
@pytest.mark.usefixtures("google_auth_mock")
374409
class TestMetadataBaseExecutionSchema:
@@ -563,6 +598,20 @@ class TestExecution(base_execution.BaseExecutionSchema):
563598
"sdk_command/aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name"
564599
]
565600

601+
def test_list_executions(self, list_executions_mock):
602+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
603+
604+
class TestExecution(base_execution.BaseExecutionSchema):
605+
schema_title = _TEST_SCHEMA_TITLE
606+
607+
TestExecution.list()
608+
list_executions_mock.assert_called_once_with(
609+
request={
610+
"parent": f"{_TEST_PARENT}/metadataStores/default",
611+
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
612+
}
613+
)
614+
566615

567616
@pytest.mark.usefixtures("google_auth_mock")
568617
class TestMetadataBaseContextSchema:
@@ -730,6 +779,20 @@ class TestContext(base_context.BaseContextSchema):
730779
"sdk_command/aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name"
731780
]
732781

782+
def test_list_contexts(self, list_contexts_mock):
783+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
784+
785+
class TestContext(base_context.BaseContextSchema):
786+
schema_title = _TEST_SCHEMA_TITLE
787+
788+
TestContext.list()
789+
list_contexts_mock.assert_called_once_with(
790+
request={
791+
"parent": f"{_TEST_PARENT}/metadataStores/default",
792+
"filter": f'schema_title="{_TEST_SCHEMA_TITLE}"',
793+
}
794+
)
795+
733796

734797
@pytest.mark.usefixtures("google_auth_mock")
735798
class TestMetadataGoogleArtifactSchema:

0 commit comments

Comments
 (0)