Skip to content

Commit 1a302d2

Browse files
authored
feat: adds function/method enhancements, demo samples (#122)
* feat: adds function/method enhancements
1 parent dd8677c commit 1a302d2

File tree

17 files changed

+331
-39
lines changed

17 files changed

+331
-39
lines changed

.github/CODEOWNERS

+3
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99

1010
# The python-samples-owners team is the default owner for samples
1111
/samples/**/*.py @dizcology @googleapis/python-samples-owners
12+
13+
# The enhanced client library tests are owned by @telpirion
14+
/tests/unit/enhanced_library/*.py @telpirion
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from google.cloud.aiplatform.helpers import value_converter
2+
3+
__all__ = (value_converter,)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import absolute_import
15+
from google.cloud.aiplatform.helpers import value_converter
16+
17+
from proto.marshal import Marshal
18+
from proto.marshal.rules.struct import ValueRule
19+
from google.protobuf.struct_pb2 import Value
20+
21+
22+
class ConversionValueRule(ValueRule):
23+
def to_python(self, value, *, absent: bool = None):
24+
return super().to_python(value, absent=absent)
25+
26+
def to_proto(self, value):
27+
28+
# Need to check whether value is an instance
29+
# of an enhanced type
30+
if callable(getattr(value, "to_value", None)):
31+
return value.to_value()
32+
else:
33+
return super().to_proto(value)
34+
35+
36+
def _add_methods_to_classes_in_package(pkg):
37+
classes = dict(
38+
[(name, cls) for name, cls in pkg.__dict__.items() if isinstance(cls, type)]
39+
)
40+
41+
for class_name, cls in classes.items():
42+
# Add to_value() method to class with docstring
43+
setattr(cls, "to_value", value_converter.to_value)
44+
cls.to_value.__doc__ = value_converter.to_value.__doc__
45+
46+
# Add from_value() method to class with docstring
47+
setattr(cls, "from_value", _add_from_value_to_class(cls))
48+
cls.from_value.__doc__ = value_converter.from_value.__doc__
49+
50+
# Add from_map() method to class with docstring
51+
setattr(cls, "from_map", _add_from_map_to_class(cls))
52+
cls.from_map.__doc__ = value_converter.from_map.__doc__
53+
54+
55+
def _add_from_value_to_class(cls):
56+
def _from_value(value):
57+
return value_converter.from_value(cls, value)
58+
59+
return _from_value
60+
61+
62+
def _add_from_map_to_class(cls):
63+
def _from_map(map_):
64+
return value_converter.from_map(cls, map_)
65+
66+
return _from_map
67+
68+
69+
marshal = Marshal(name="google.cloud.aiplatform.v1beta1")
70+
marshal.register(Value, ConversionValueRule(marshal=marshal))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import absolute_import
15+
from google.protobuf.struct_pb2 import Value
16+
from google.protobuf import json_format
17+
from proto.marshal.collections.maps import MapComposite
18+
from proto.marshal import Marshal
19+
from proto import Message
20+
from proto.message import MessageMeta
21+
22+
23+
def to_value(self: Message) -> Value:
24+
"""Converts a message type to a :class:`~google.protobuf.struct_pb2.Value` object.
25+
26+
Args:
27+
message: the message to convert
28+
29+
Returns:
30+
the message as a :class:`~google.protobuf.struct_pb2.Value` object
31+
"""
32+
tmp_dict = json_format.MessageToDict(self._pb)
33+
return json_format.ParseDict(tmp_dict, Value())
34+
35+
36+
def from_value(cls: MessageMeta, value: Value) -> Message:
37+
"""Creates instance of class from a :class:`~google.protobuf.struct_pb2.Value` object.
38+
39+
Args:
40+
value: a :class:`~google.protobuf.struct_pb2.Value` object
41+
42+
Returns:
43+
Instance of class
44+
"""
45+
value_dict = json_format.MessageToDict(value)
46+
return json_format.ParseDict(value_dict, cls()._pb)
47+
48+
49+
def from_map(cls: MessageMeta, map_: MapComposite) -> Message:
50+
"""Creates instance of class from a :class:`~proto.marshal.collections.maps.MapComposite` object.
51+
52+
Args:
53+
map_: a :class:`~proto.marshal.collections.maps.MapComposite` object
54+
55+
Returns:
56+
Instance of class
57+
"""
58+
marshal = Marshal(name="marshal")
59+
pb = marshal.to_proto(Value, map_)
60+
return from_value(cls, pb)

google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
from google.cloud.aiplatform.helpers import _decorators
18+
import google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types as pkg
1719

1820
from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import (
1921
ImageClassificationPredictionInstance,
@@ -54,3 +56,4 @@
5456
"VideoClassificationPredictionInstance",
5557
"VideoObjectTrackingPredictionInstance",
5658
)
59+
_decorators._add_methods_to_classes_in_package(pkg)

google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
from google.cloud.aiplatform.helpers import _decorators
18+
import google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types as pkg
1719

1820
from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import (
1921
ImageClassificationPredictionParams,
@@ -42,3 +44,4 @@
4244
"VideoClassificationPredictionParams",
4345
"VideoObjectTrackingPredictionParams",
4446
)
47+
_decorators._add_methods_to_classes_in_package(pkg)

google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
from google.cloud.aiplatform.helpers import _decorators
18+
import google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types as pkg
1719

1820
from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import (
1921
ClassificationPredictionResult,
@@ -62,3 +64,4 @@
6264
"VideoClassificationPredictionResult",
6365
"VideoObjectTrackingPredictionResult",
6466
)
67+
_decorators._add_methods_to_classes_in_package(pkg)

google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import proto # type: ignore
1919

20-
21-
from google.cloud.aiplatform.v1beta1.schema.predict.instance import text_sentiment_pb2 as gcaspi_text_sentiment # type: ignore
20+
# DO NOT OVERWRITE FOLLOWING LINE: it was manually edited.
21+
from google.cloud.aiplatform.v1beta1.schema.predict.instance import (
22+
TextSentimentPredictionInstance,
23+
)
2224

2325

2426
__protobuf__ = proto.module(
@@ -57,9 +59,7 @@ class Prediction(proto.Message):
5759
sentiment = proto.Field(proto.INT32, number=1)
5860

5961
instance = proto.Field(
60-
proto.MESSAGE,
61-
number=1,
62-
message=gcaspi_text_sentiment.TextSentimentPredictionInstance,
62+
proto.MESSAGE, number=1, message=TextSentimentPredictionInstance,
6363
)
6464

6565
prediction = proto.Field(proto.MESSAGE, number=2, message=Prediction,)

google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
from google.cloud.aiplatform.helpers import _decorators
18+
import google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types as pkg
1719

1820
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_forecasting import (
1921
AutoMlForecasting,
@@ -130,3 +132,4 @@
130132
"AutoMlVideoObjectTrackingInputs",
131133
"ExportEvaluatedDataItemsConfig",
132134
)
135+
_decorators._add_methods_to_classes_in_package(pkg)

google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ class AutoMlForecastingInputs(proto.Message):
7878
function over the validation set.
7979
8080
The supported optimization objectives:
81-
"minimize-rmse" (default) - Minimize root-
81+
"minimize-rmse" (default) - Minimize root-
8282
mean-squared error (RMSE). "minimize-mae" -
8383
Minimize mean-absolute error (MAE). "minimize-
8484
rmsle" - Minimize root-mean-squared log error
8585
(RMSLE). "minimize-rmspe" - Minimize root-
8686
mean-squared percentage error (RMSPE).
8787
"minimize-wape-mae" - Minimize the combination
88-
of weighted absolute percentage error (WAPE)
88+
of weighted absolute percentage error (WAPE)
8989
and mean-absolute-error (MAE).
9090
train_budget_milli_node_hours (int):
9191
Required. The train budget of creating this
@@ -418,11 +418,11 @@ class Period(proto.Message):
418418
unit (str):
419419
The time granularity unit of this time
420420
period. The supported unit are:
421-
"hour"
422-
"day"
423-
"week"
424-
"month"
425-
"year".
421+
"hour"
422+
"day"
423+
"week"
424+
"month"
425+
"year".
426426
quantity (int):
427427
The number of units per period, e.g. 3 weeks
428428
or 2 months.

google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class AutoMlTablesInputs(proto.Message):
6161
produce. "classification" - Predict one out of
6262
multiple target values is
6363
picked for each row.
64-
"regression" - Predict a value based on its
64+
"regression" - Predict a value based on its
6565
relation to other values. This
6666
type is available only to columns that contain
6767
semantically numeric values, i.e. integers or
@@ -87,22 +87,22 @@ class AutoMlTablesInputs(proto.Message):
8787
the prediction type. If the field is not set, a
8888
default objective function is used.
8989
classification (binary):
90-
"maximize-au-roc" (default) - Maximize the
90+
"maximize-au-roc" (default) - Maximize the
9191
area under the receiver
9292
operating characteristic (ROC) curve.
9393
"minimize-log-loss" - Minimize log loss.
94-
"maximize-au-prc" - Maximize the area under
94+
"maximize-au-prc" - Maximize the area under
9595
the precision-recall curve. "maximize-
9696
precision-at-recall" - Maximize precision for a
9797
specified
9898
recall value. "maximize-recall-at-precision" -
9999
Maximize recall for a specified
100100
precision value.
101101
classification (multi-class):
102-
"minimize-log-loss" (default) - Minimize log
102+
"minimize-log-loss" (default) - Minimize log
103103
loss.
104104
regression:
105-
"minimize-rmse" (default) - Minimize root-
105+
"minimize-rmse" (default) - Minimize root-
106106
mean-squared error (RMSE). "minimize-mae" -
107107
Minimize mean-absolute error (MAE). "minimize-
108108
rmsle" - Minimize root-mean-squared log error

samples/snippets/create_training_pipeline_image_classification_sample.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
# [START aiplatform_create_training_pipeline_image_classification_sample]
1616
from google.cloud import aiplatform
17-
from google.protobuf import json_format
18-
from google.protobuf.struct_pb2 import Value
17+
from google.cloud.aiplatform.v1beta1.schema.trainingjob import definition
18+
ModelType = definition.AutoMlImageClassificationInputs().ModelType
1919

2020

2121
def create_training_pipeline_image_classification_sample(
@@ -31,13 +31,14 @@ def create_training_pipeline_image_classification_sample(
3131
# Initialize client that will be used to create and send requests.
3232
# This client only needs to be created once, and can be reused for multiple requests.
3333
client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
34-
training_task_inputs_dict = {
35-
"multiLabel": True,
36-
"modelType": "CLOUD",
37-
"budgetMilliNodeHours": 8000,
38-
"disableEarlyStopping": False,
39-
}
40-
training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())
34+
35+
icn_training_inputs = definition.AutoMlImageClassificationInputs(
36+
multi_label=True,
37+
model_type=ModelType.CLOUD,
38+
budget_milli_node_hours=8000,
39+
disable_early_stopping=False
40+
)
41+
training_task_inputs = icn_training_inputs.to_value()
4142

4243
training_pipeline = {
4344
"display_name": display_name,

samples/snippets/predict_image_classification_sample.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import base64
1717

1818
from google.cloud import aiplatform
19-
from google.protobuf import json_format
20-
from google.protobuf.struct_pb2 import Value
19+
from google.cloud.aiplatform.v1beta1.schema.predict import instance
20+
from google.cloud.aiplatform.v1beta1.schema.predict import params
21+
from google.cloud.aiplatform.v1beta1.schema.predict import prediction
2122

2223

2324
def predict_image_classification_sample(
@@ -37,25 +38,29 @@ def predict_image_classification_sample(
3738

3839
# The format of each instance should conform to the deployed model's prediction input schema.
3940
encoded_content = base64.b64encode(file_content).decode("utf-8")
40-
instance_dict = {"content": encoded_content}
4141

42-
instance = json_format.ParseDict(instance_dict, Value())
43-
instances = [instance]
44-
# See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters.
45-
parameters_dict = {"confidence_threshold": 0.5, "max_predictions": 5}
46-
parameters = json_format.ParseDict(parameters_dict, Value())
42+
instance_obj = instance.ImageClassificationPredictionInstance(
43+
content=encoded_content)
44+
45+
instance_val = instance_obj.to_value()
46+
instances = [instance_val]
47+
48+
params_obj = params.ImageClassificationPredictionParams(
49+
confidence_threshold=0.5, max_predictions=5)
50+
4751
endpoint = client.endpoint_path(
4852
project=project, location=location, endpoint=endpoint_id
4953
)
5054
response = client.predict(
51-
endpoint=endpoint, instances=instances, parameters=parameters
55+
endpoint=endpoint, instances=instances, parameters=params_obj
5256
)
5357
print("response")
54-
print(" deployed_model_id:", response.deployed_model_id)
58+
print("\tdeployed_model_id:", response.deployed_model_id)
5559
# See gs://google-cloud-aiplatform/schema/predict/prediction/classification.yaml for the format of the predictions.
5660
predictions = response.predictions
57-
for prediction in predictions:
58-
print(" prediction:", dict(prediction))
61+
for prediction_ in predictions:
62+
prediction_obj = prediction.ClassificationPredictionResult.from_map(prediction_)
63+
print(prediction_obj)
5964

6065

6166
# [END aiplatform_predict_image_classification_sample]

samples/snippets/predict_image_classification_sample_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ def test_ucaip_generated_predict_image_classification_sample(capsys):
3131
)
3232

3333
out, _ = capsys.readouterr()
34-
assert 'string_value: "daisy"' in out
34+
assert 'deployed_model_id:' in out

0 commit comments

Comments
 (0)