13
13
# limitations under the License.
14
14
#
15
15
16
- from typing import Literal , Optional , Union
16
+ from typing import Dict , Literal , Optional , Union
17
17
18
18
from google .cloud .aiplatform_v1 .types import tuning_job as gca_tuning_job_types
19
-
20
19
from vertexai import generative_models
21
20
from vertexai .tuning import _tuning
22
21
@@ -30,27 +29,28 @@ def train(
30
29
epochs : Optional [int ] = None ,
31
30
learning_rate_multiplier : Optional [float ] = None ,
32
31
adapter_size : Optional [Literal [1 , 4 , 8 , 16 ]] = None ,
32
+ labels : Optional [Dict [str , str ]] = None ,
33
33
) -> "SupervisedTuningJob" :
34
- """Tunes a model using supervised training.
34
+ """Tunes a model using supervised training.
35
35
36
- Args:
37
- source_model (str):
38
- Model name for tuning, e.g., "gemini-1.0-pro-002".
39
- train_dataset: Cloud Storage path to file containing training dataset for tuning .
40
- The dataset should be in JSONL format.
41
- validation_dataset: Cloud Storage path to file containing validation dataset for tuning .
42
- The dataset should be in JSONL format.
43
- tuned_model_display_name: The display name of the
44
- [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
45
- be up to 128 characters long and can consist of any UTF-8 characters .
46
- epochs: Number of training epoches for this tuning job .
47
- learning_rate_multiplier: Learning rate multiplier for tuning.
48
- adapter_size: Adapter size for tuning.
36
+ Args:
37
+ source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002".
38
+ train_dataset: Cloud Storage path to file containing training dataset for
39
+ tuning. The dataset should be in JSONL format .
40
+ validation_dataset: Cloud Storage path to file containing validation
41
+ dataset for tuning. The dataset should be in JSONL format .
42
+ tuned_model_display_name: The display name of the
43
+ [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
44
+ 128 characters long and can consist of any UTF-8 characters.
45
+ epochs: Number of training epoches for this tuning job .
46
+ learning_rate_multiplier: Learning rate multiplier for tuning.
47
+ adapter_size: Adapter size for tuning.
48
+ labels: User-defined metadata to be associated with trained models
49
49
50
- Returns:
51
- A `TuningJob` object.
52
- """
53
- supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
50
+ Returns:
51
+ A `TuningJob` object.
52
+ """
53
+ supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
54
54
training_dataset_uri = train_dataset ,
55
55
validation_dataset_uri = validation_dataset ,
56
56
hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
@@ -60,14 +60,15 @@ def train(
60
60
),
61
61
)
62
62
63
- if isinstance (source_model , generative_models .GenerativeModel ):
64
- source_model = source_model ._prediction_resource_name .rpartition ('/' )[- 1 ]
63
+ if isinstance (source_model , generative_models .GenerativeModel ):
64
+ source_model = source_model ._prediction_resource_name .rpartition ('/' )[- 1 ]
65
65
66
- return SupervisedTuningJob ._create ( # pylint: disable=protected-access
67
- base_model = source_model ,
68
- tuning_spec = supervised_tuning_spec ,
69
- tuned_model_display_name = tuned_model_display_name ,
70
- )
66
+ return SupervisedTuningJob ._create ( # pylint: disable=protected-access
67
+ base_model = source_model ,
68
+ tuning_spec = supervised_tuning_spec ,
69
+ tuned_model_display_name = tuned_model_display_name ,
70
+ labels = labels ,
71
+ )
71
72
72
73
73
74
class SupervisedTuningJob (_tuning .TuningJob ):
0 commit comments