Skip to content

Commit dcb6205

Browse files
lingyinwcopybara-github
authored andcommitted
feat: add index_update_method to MatchingEngineIndex create()
PiperOrigin-RevId: 580589542
1 parent 21686ae commit dcb6205

File tree

2 files changed

+72
-17
lines changed

2 files changed

+72
-17
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+30-15
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def _create(
108108
credentials: Optional[auth_credentials.Credentials] = None,
109109
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
110110
sync: bool = True,
111+
index_update_method: Optional[str] = None,
111112
) -> "MatchingEngineIndex":
112113
"""Creates a MatchingEngineIndex resource.
113114
@@ -153,27 +154,33 @@ def _create(
153154
credentials set in aiplatform.init.
154155
request_metadata (Sequence[Tuple[str, str]]):
155156
Optional. Strings which should be sent along with the request as metadata.
156-
encryption_spec (str):
157-
Optional. Customer-managed encryption key
158-
spec for data storage. If set, both of the
159-
online and offline data storage will be secured
160-
by this key.
161157
sync (bool):
162158
Optional. Whether to execute this creation synchronously. If False, this method
163159
will be executed in concurrent Future and any downstream object will
164160
be immediately returned and synced when the Future has completed.
161+
index_update_method (str):
162+
Optional. The update method to use with this index. Choose
163+
stream_update or batch_update. If not set, batch update will be
164+
used by default.
165165
166166
Returns:
167167
MatchingEngineIndex - Index resource object
168168
169169
"""
170+
index_update_method_enum = None
171+
if index_update_method in _INDEX_UPDATE_METHOD_TO_ENUM_VALUE:
172+
index_update_method_enum = _INDEX_UPDATE_METHOD_TO_ENUM_VALUE[
173+
index_update_method
174+
]
175+
170176
gapic_index = gca_matching_engine_index.Index(
171177
display_name=display_name,
172178
description=description,
173179
metadata={
174180
"config": config.as_dict(),
175181
"contentsDeltaUri": contents_delta_uri,
176182
},
183+
index_update_method=index_update_method_enum,
177184
)
178185

179186
if labels:
@@ -386,6 +393,7 @@ def create_tree_ah_index(
386393
credentials: Optional[auth_credentials.Credentials] = None,
387394
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
388395
sync: bool = True,
396+
index_update_method: Optional[str] = None,
389397
) -> "MatchingEngineIndex":
390398
"""Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
391399
@@ -456,15 +464,14 @@ def create_tree_ah_index(
456464
credentials set in aiplatform.init.
457465
request_metadata (Sequence[Tuple[str, str]]):
458466
Optional. Strings which should be sent along with the request as metadata.
459-
encryption_spec (str):
460-
Optional. Customer-managed encryption key
461-
spec for data storage. If set, both of the
462-
online and offline data storage will be secured
463-
by this key.
464467
sync (bool):
465468
Optional. Whether to execute this creation synchronously. If False, this method
466469
will be executed in concurrent Future and any downstream object will
467470
be immediately returned and synced when the Future has completed.
471+
index_update_method (str):
472+
Optional. The update method to use with this index. Choose
473+
STREAM_UPDATE or BATCH_UPDATE. If not set, batch update will be
474+
used by default.
468475
469476
Returns:
470477
MatchingEngineIndex - Index resource object
@@ -494,6 +501,7 @@ def create_tree_ah_index(
494501
credentials=credentials,
495502
request_metadata=request_metadata,
496503
sync=sync,
504+
index_update_method=index_update_method,
497505
)
498506

499507
@classmethod
@@ -512,6 +520,7 @@ def create_brute_force_index(
512520
credentials: Optional[auth_credentials.Credentials] = None,
513521
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
514522
sync: bool = True,
523+
index_update_method: Optional[str] = None,
515524
) -> "MatchingEngineIndex":
516525
"""Creates a MatchingEngineIndex resource that uses the brute force algorithm.
517526
@@ -571,15 +580,14 @@ def create_brute_force_index(
571580
credentials set in aiplatform.init.
572581
request_metadata (Sequence[Tuple[str, str]]):
573582
Optional. Strings which should be sent along with the request as metadata.
574-
encryption_spec (str):
575-
Optional. Customer-managed encryption key
576-
spec for data storage. If set, both of the
577-
online and offline data storage will be secured
578-
by this key.
579583
sync (bool):
580584
Optional. Whether to execute this creation synchronously. If False, this method
581585
will be executed in concurrent Future and any downstream object will
582586
be immediately returned and synced when the Future has completed.
587+
index_update_method (str):
588+
Optional. The update method to use with this index. Choose
589+
stream_update or batch_update. If not set, batch update will be
590+
used by default.
583591
584592
Returns:
585593
MatchingEngineIndex - Index resource object
@@ -605,4 +613,11 @@ def create_brute_force_index(
605613
credentials=credentials,
606614
request_metadata=request_metadata,
607615
sync=sync,
616+
index_update_method=index_update_method,
608617
)
618+
619+
620+
_INDEX_UPDATE_METHOD_TO_ENUM_VALUE = {
621+
"STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE,
622+
"BATCH_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.BATCH_UPDATE,
623+
}

tests/unit/aiplatform/test_matching_engine_index.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@
9292
),
9393
]
9494

95+
# Index update method
96+
_TEST_INDEX_BATCH_UPDATE_METHOD = "BATCH_UPDATE"
97+
_TEST_INDEX_STREAM_UPDATE_METHOD = "STREAM_UPDATE"
98+
_TEST_INDEX_EMPTY_UPDATE_METHOD = None
99+
_TEST_INDEX_INVALID_UPDATE_METHOD = "INVALID_UPDATE_METHOD"
100+
_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP = {
101+
_TEST_INDEX_BATCH_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.BATCH_UPDATE,
102+
_TEST_INDEX_STREAM_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.STREAM_UPDATE,
103+
_TEST_INDEX_EMPTY_UPDATE_METHOD: None,
104+
_TEST_INDEX_INVALID_UPDATE_METHOD: None,
105+
}
106+
95107

96108
def uuid_mock():
97109
return uuid.UUID(int=1)
@@ -273,7 +285,16 @@ def test_delete_index(self, delete_index_mock, sync):
273285

274286
@pytest.mark.usefixtures("get_index_mock")
275287
@pytest.mark.parametrize("sync", [True, False])
276-
def test_create_tree_ah_index(self, create_index_mock, sync):
288+
@pytest.mark.parametrize(
289+
"index_update_method",
290+
[
291+
_TEST_INDEX_STREAM_UPDATE_METHOD,
292+
_TEST_INDEX_BATCH_UPDATE_METHOD,
293+
_TEST_INDEX_EMPTY_UPDATE_METHOD,
294+
_TEST_INDEX_INVALID_UPDATE_METHOD,
295+
],
296+
)
297+
def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method):
277298
aiplatform.init(project=_TEST_PROJECT)
278299

279300
my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
@@ -287,6 +308,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
287308
description=_TEST_INDEX_DESCRIPTION,
288309
labels=_TEST_LABELS,
289310
sync=sync,
311+
index_update_method=index_update_method,
290312
)
291313

292314
if not sync:
@@ -312,6 +334,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
312334
},
313335
description=_TEST_INDEX_DESCRIPTION,
314336
labels=_TEST_LABELS,
337+
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
338+
index_update_method
339+
],
315340
)
316341

317342
create_index_mock.assert_called_once_with(
@@ -322,7 +347,18 @@ def test_create_tree_ah_index(self, create_index_mock, sync):
322347

323348
@pytest.mark.usefixtures("get_index_mock")
324349
@pytest.mark.parametrize("sync", [True, False])
325-
def test_create_brute_force_index(self, create_index_mock, sync):
350+
@pytest.mark.parametrize(
351+
"index_update_method",
352+
[
353+
_TEST_INDEX_STREAM_UPDATE_METHOD,
354+
_TEST_INDEX_BATCH_UPDATE_METHOD,
355+
_TEST_INDEX_EMPTY_UPDATE_METHOD,
356+
_TEST_INDEX_INVALID_UPDATE_METHOD,
357+
],
358+
)
359+
def test_create_brute_force_index(
360+
self, create_index_mock, sync, index_update_method
361+
):
326362
aiplatform.init(project=_TEST_PROJECT)
327363

328364
my_index = aiplatform.MatchingEngineIndex.create_brute_force_index(
@@ -333,6 +369,7 @@ def test_create_brute_force_index(self, create_index_mock, sync):
333369
description=_TEST_INDEX_DESCRIPTION,
334370
labels=_TEST_LABELS,
335371
sync=sync,
372+
index_update_method=index_update_method,
336373
)
337374

338375
if not sync:
@@ -353,6 +390,9 @@ def test_create_brute_force_index(self, create_index_mock, sync):
353390
},
354391
description=_TEST_INDEX_DESCRIPTION,
355392
labels=_TEST_LABELS,
393+
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
394+
index_update_method
395+
],
356396
)
357397

358398
create_index_mock.assert_called_once_with(

0 commit comments

Comments
 (0)