Skip to content

Commit c2cf612

Browse files
authored
feat: allow the prediction endpoint to be overridden (#461)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 8cfd611 commit c2cf612

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

google/cloud/aiplatform/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
}
3434

3535
API_BASE_PATH = "aiplatform.googleapis.com"
36+
PREDICTION_API_BASE_PATH = API_BASE_PATH
3637

3738
# Batch Prediction
3839
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (

google/cloud/aiplatform/initializer.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def encryption_spec_key_name(self) -> Optional[str]:
194194
return self._encryption_spec_key_name
195195

196196
def get_client_options(
197-
self, location_override: Optional[str] = None
197+
self, location_override: Optional[str] = None, prediction_client: bool = False
198198
) -> client_options.ClientOptions:
199199
"""Creates GAPIC client_options using location and type.
200200
@@ -203,6 +203,8 @@ def get_client_options(
203203
Set this parameter to get client options for a location different from
204204
location set by initializer. Must be a GCP region supported by AI
205205
Platform (Unified).
206+
prediction_client (str): Optional flag to use a prediction endpoint.
207+
206208
207209
Returns:
208210
clients_options (google.api_core.client_options.ClientOptions):
@@ -220,8 +222,14 @@ def get_client_options(
220222

221223
utils.validate_region(region)
222224

225+
service_base_path = (
226+
constants.PREDICTION_API_BASE_PATH
227+
if prediction_client
228+
else constants.API_BASE_PATH
229+
)
230+
223231
return client_options.ClientOptions(
224-
api_endpoint=f"{region}-{constants.API_BASE_PATH}"
232+
api_endpoint=f"{region}-{service_base_path}"
225233
)
226234

227235
def common_location_path(
@@ -278,7 +286,8 @@ def create_client(
278286
kwargs = {
279287
"credentials": credentials or self.credentials,
280288
"client_options": self.get_client_options(
281-
location_override=location_override
289+
location_override=location_override,
290+
prediction_client=prediction_client,
282291
),
283292
"client_info": client_info,
284293
}

0 commit comments

Comments
 (0)