Skip to content

Commit f7c5567

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: add labels parameter to the supervised tuning train method
PiperOrigin-RevId: 636381156
1 parent 0936f35 commit f7c5567

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

vertexai/tuning/_supervised_tuning.py

+28-27
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,9 @@
1313
# limitations under the License.
1414
#
1515

16-
from typing import Literal, Optional, Union
16+
from typing import Dict, Literal, Optional, Union
1717

1818
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
19-
2019
from vertexai import generative_models
2120
from vertexai.tuning import _tuning
2221

@@ -30,27 +29,28 @@ def train(
3029
epochs: Optional[int] = None,
3130
learning_rate_multiplier: Optional[float] = None,
3231
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
32+
labels: Optional[Dict[str, str]] = None,
3333
) -> "SupervisedTuningJob":
34-
"""Tunes a model using supervised training.
34+
"""Tunes a model using supervised training.
3535
36-
Args:
37-
source_model (str):
38-
Model name for tuning, e.g., "gemini-1.0-pro-002".
39-
train_dataset: Cloud Storage path to file containing training dataset for tuning.
40-
The dataset should be in JSONL format.
41-
validation_dataset: Cloud Storage path to file containing validation dataset for tuning.
42-
The dataset should be in JSONL format.
43-
tuned_model_display_name: The display name of the
44-
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can
45-
be up to 128 characters long and can consist of any UTF-8 characters.
46-
epochs: Number of training epochs for this tuning job.
47-
learning_rate_multiplier: Learning rate multiplier for tuning.
48-
adapter_size: Adapter size for tuning.
36+
Args:
37+
source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
38+
train_dataset: Cloud Storage path to file containing training dataset for
39+
tuning. The dataset should be in JSONL format.
40+
validation_dataset: Cloud Storage path to file containing validation
41+
dataset for tuning. The dataset should be in JSONL format.
42+
tuned_model_display_name: The display name of the
43+
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
44+
128 characters long and can consist of any UTF-8 characters.
45+
epochs: Number of training epochs for this tuning job.
46+
learning_rate_multiplier: Learning rate multiplier for tuning.
47+
adapter_size: Adapter size for tuning.
48+
labels: User-defined metadata to be associated with trained models.
4949
50-
Returns:
51-
A `TuningJob` object.
52-
"""
53-
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
50+
Returns:
51+
A `TuningJob` object.
52+
"""
53+
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
5454
training_dataset_uri=train_dataset,
5555
validation_dataset_uri=validation_dataset,
5656
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
@@ -60,14 +60,15 @@ def train(
6060
),
6161
)
6262

63-
if isinstance(source_model, generative_models.GenerativeModel):
64-
source_model = source_model._prediction_resource_name.rpartition('/')[-1]
63+
if isinstance(source_model, generative_models.GenerativeModel):
64+
source_model = source_model._prediction_resource_name.rpartition('/')[-1]
6565

66-
return SupervisedTuningJob._create( # pylint: disable=protected-access
67-
base_model=source_model,
68-
tuning_spec=supervised_tuning_spec,
69-
tuned_model_display_name=tuned_model_display_name,
70-
)
66+
return SupervisedTuningJob._create( # pylint: disable=protected-access
67+
base_model=source_model,
68+
tuning_spec=supervised_tuning_spec,
69+
tuned_model_display_name=tuned_model_display_name,
70+
labels=labels,
71+
)
7172

7273

7374
class SupervisedTuningJob(_tuning.TuningJob):

0 commit comments

Comments
 (0)