
Commit f5a20eb

Ark-kun authored and copybara-github committed
fix: LLM - De-hardcoded the max_output_tokens default value for the CodeGenerationModel
The previous default value (128) was inconsistent with the service-side default values of the models: `code-bison` has 1024 and `code-gecko` has 64. Moreover, the default value was out of range for the `code-gecko` model. This CL fixes these issues: the SDK now relies on the service-side default values when the user does not pass a parameter value explicitly.

What can change: when using the `code-bison` model, the default value of `max_output_tokens` effectively increases from 128 to 1024 (the current service-side default value).

PiperOrigin-RevId: 559266968
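The fix boils down to building the request parameters only from values the caller actually set. Below is a minimal sketch of that omit-if-unset pattern; `build_prediction_parameters` is a hypothetical helper for illustration, not the SDK's real request-building code:

```python
from typing import Optional


def build_prediction_parameters(
    max_output_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
) -> dict:
    """Builds request parameters, omitting anything the caller left unset.

    Keys absent from the dict are filled in by the service with its own
    per-model defaults (e.g. 1024 for code-bison, 64 for code-gecko).
    """
    parameters = {}
    if max_output_tokens is not None:
        parameters["maxOutputTokens"] = max_output_tokens
    if temperature is not None:
        parameters["temperature"] = temperature
    return parameters


# With no arguments the dict stays empty, so the service defaults apply.
assert "maxOutputTokens" not in build_prediction_parameters()
assert build_prediction_parameters(max_output_tokens=100) == {"maxOutputTokens": 100}
```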
1 parent: d11b8e6 · commit: f5a20eb

2 files changed: +5 −12 lines changed

tests/unit/aiplatform/test_language_models.py (+2 −8)
```diff
@@ -2187,9 +2187,6 @@ def test_code_generation(self):
         # Validating the parameters
         predict_temperature = 0.1
         predict_max_output_tokens = 100
-        default_max_output_tokens = (
-            language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
-        )
         stop_sequences = ["\n"]
 
         with mock.patch.object(
@@ -2213,7 +2210,7 @@ def test_code_generation(self):
         )
         prediction_parameters = mock_predict.call_args[1]["parameters"]
         assert "temperature" not in prediction_parameters
-        assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
+        assert "maxOutputTokens" not in prediction_parameters
 
     def test_code_completion(self):
         """Tests code completion with the code generation model."""
@@ -2255,9 +2252,6 @@ def test_code_completion(self):
         # Validating the parameters
         predict_temperature = 0.1
         predict_max_output_tokens = 100
-        default_max_output_tokens = (
-            language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
-        )
 
         with mock.patch.object(
             target=prediction_service_client.PredictionServiceClient,
@@ -2278,7 +2272,7 @@ def test_code_completion(self):
         )
         prediction_parameters = mock_predict.call_args[1]["parameters"]
         assert "temperature" not in prediction_parameters
-        assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens
+        assert "maxOutputTokens" not in prediction_parameters
 
     def test_code_generation_model_predict_streaming(self):
         """Tests the TextGenerationModel.predict_streaming method."""
```

vertexai/language_models/_language_models.py (+3 −4)
```diff
@@ -1686,14 +1686,13 @@ class CodeGenerationModel(_LanguageModel):
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
 
     _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
-    _DEFAULT_MAX_OUTPUT_TOKENS = 128
 
     def _create_prediction_request(
         self,
         prefix: str,
         suffix: Optional[str] = None,
         *,
-        max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+        max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
     ) -> _PredictionRequest:
@@ -1732,7 +1731,7 @@ def predict(
         prefix: str,
         suffix: Optional[str] = None,
         *,
-        max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+        max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
     ) -> "TextGenerationResponse":
@@ -1771,7 +1770,7 @@ def predict_streaming(
         prefix: str,
         suffix: Optional[str] = None,
         *,
-        max_output_tokens: Optional[int] = _DEFAULT_MAX_OUTPUT_TOKENS,
+        max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
     ) -> Iterator[TextGenerationResponse]:
         """Predicts the code based on previous code.
```
