Skip to content

Commit 5f0ea37

Browse files
authored
feat: support 32k text-generation and multilingual embedding models (#161)
* feat: support 32k text-generation and embedding multilingual models
1 parent e1817c9 commit 5f0ea37

File tree

3 files changed

+118
-12
lines changed

3 files changed

+118
-12
lines changed

bigframes/ml/llm.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import cast, Optional, Union
19+
from typing import cast, Literal, Optional, Union
2020

2121
import bigframes
2222
from bigframes import clients, constants
@@ -25,29 +25,37 @@
2525
import bigframes.pandas as bpd
2626

2727
_REMOTE_TEXT_GENERATOR_MODEL_CODE = "CLOUD_AI_LARGE_LANGUAGE_MODEL_V1"
28+
_REMOTE_TEXT_GENERATOR_32K_MODEL_CODE = "text-bison-32k"
2829
_TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result"
2930

3031
_REMOTE_EMBEDDING_GENERATOR_MODEL_CODE = "CLOUD_AI_TEXT_EMBEDDING_MODEL_V1"
32+
_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_CODE = "textembedding-gecko-multilingual"
3133
_EMBED_TEXT_RESULT_COLUMN = "text_embedding"
3234

3335

3436
class PaLM2TextGenerator(base.Predictor):
3537
"""PaLM2 text generator LLM model.
3638
3739
Args:
40+
model_name (str, Default to "text-bison"):
41+
The model for natural language tasks. "text-bison" returns a model fine-tuned to follow natural language instructions
42+
and is suitable for a variety of language tasks. "text-bison-32k" supports up to 32k tokens per request.
43+
Default to "text-bison".
3844
session (bigframes.Session or None):
3945
BQ session to create the model. If None, use the global default session.
4046
connection_name (str or None):
41-
connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
47+
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
4248
if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
4349
permission if the connection isn't fully setup.
4450
"""
4551

4652
def __init__(
4753
self,
54+
model_name: Literal["text-bison", "text-bison-32k"] = "text-bison",
4855
session: Optional[bigframes.Session] = None,
4956
connection_name: Optional[str] = None,
5057
):
58+
self.model_name = model_name
5159
self.session = session or bpd.get_global_session()
5260
self._bq_connection_manager = clients.BqConnectionManager(
5361
self.session.bqconnectionclient, self.session.resourcemanagerclient
@@ -80,11 +88,14 @@ def _create_bqml_model(self):
8088
connection_id=connection_name_parts[2],
8189
iam_role="aiplatform.user",
8290
)
83-
84-
options = {
85-
"remote_service_type": _REMOTE_TEXT_GENERATOR_MODEL_CODE,
86-
}
87-
91+
if self.model_name == "text-bison":
92+
options = {
93+
"remote_service_type": _REMOTE_TEXT_GENERATOR_MODEL_CODE,
94+
}
95+
else:
96+
options = {
97+
"endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_CODE,
98+
}
8899
return self._bqml_model_factory.create_remote_model(
89100
session=self.session, connection_name=self.connection_name, options=options
90101
)
@@ -118,7 +129,7 @@ def predict(
118129
119130
top_k (int, default 40):
120131
Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens
121-
in the models vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature).
132+
in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature).
122133
For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on topP with the final token selected using temperature sampling.
123134
Specify a lower value for less random responses and a higher value for more random responses.
124135
Default 40. Possible values [1, 40].
@@ -175,6 +186,10 @@ class PaLM2TextEmbeddingGenerator(base.Predictor):
175186
"""PaLM2 text embedding generator LLM model.
176187
177188
Args:
189+
model_name (str, Default to "textembedding-gecko"):
190+
The model for text embedding. "textembedding-gecko" returns model embeddings for text inputs.
191+
"textembedding-gecko-multilingual" returns model embeddings for text inputs which support over 100 languages.
192+
Default to "textembedding-gecko".
178193
session (bigframes.Session or None):
179194
BQ session to create the model. If None, use the global default session.
180195
connection_name (str or None):
@@ -184,9 +199,13 @@ class PaLM2TextEmbeddingGenerator(base.Predictor):
184199

185200
def __init__(
186201
self,
202+
model_name: Literal[
203+
"textembedding-gecko", "textembedding-gecko-multilingual"
204+
] = "textembedding-gecko",
187205
session: Optional[bigframes.Session] = None,
188206
connection_name: Optional[str] = None,
189207
):
208+
self.model_name = model_name
190209
self.session = session or bpd.get_global_session()
191210
self._bq_connection_manager = clients.BqConnectionManager(
192211
self.session.bqconnectionclient, self.session.resourcemanagerclient
@@ -219,10 +238,14 @@ def _create_bqml_model(self):
219238
connection_id=connection_name_parts[2],
220239
iam_role="aiplatform.user",
221240
)
222-
223-
options = {
224-
"remote_service_type": _REMOTE_EMBEDDING_GENERATOR_MODEL_CODE,
225-
}
241+
if self.model_name == "textembedding-gecko":
242+
options = {
243+
"remote_service_type": _REMOTE_EMBEDDING_GENERATOR_MODEL_CODE,
244+
}
245+
else:
246+
options = {
247+
"endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_CODE,
248+
}
226249

227250
return self._bqml_model_factory.create_remote_model(
228251
session=self.session, connection_name=self.connection_name, options=options

tests/system/small/ml/conftest.py

+18
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,13 @@ def palm2_text_generator_model(session, bq_connection) -> llm.PaLM2TextGenerator
213213
return llm.PaLM2TextGenerator(session=session, connection_name=bq_connection)
214214

215215

216+
@pytest.fixture(scope="session")
217+
def palm2_text_generator_32k_model(session, bq_connection) -> llm.PaLM2TextGenerator:
218+
return llm.PaLM2TextGenerator(
219+
model_name="text-bison-32k", session=session, connection_name=bq_connection
220+
)
221+
222+
216223
@pytest.fixture(scope="function")
217224
def ephemera_palm2_text_generator_model(
218225
session, bq_connection
@@ -229,6 +236,17 @@ def palm2_embedding_generator_model(
229236
)
230237

231238

239+
@pytest.fixture(scope="session")
240+
def palm2_embedding_generator_multilingual_model(
241+
session, bq_connection
242+
) -> llm.PaLM2TextEmbeddingGenerator:
243+
return llm.PaLM2TextEmbeddingGenerator(
244+
model_name="textembedding-gecko-multilingual",
245+
session=session,
246+
connection_name=bq_connection,
247+
)
248+
249+
232250
@pytest.fixture(scope="session")
233251
def time_series_bqml_arima_plus_model(
234252
session, time_series_arima_plus_model_name

tests/system/small/ml/test_llm.py

+65
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ def test_create_text_generator_model(palm2_text_generator_model):
2626
assert palm2_text_generator_model._bqml_model is not None
2727

2828

29+
def test_create_text_generator_32k_model(palm2_text_generator_32k_model):
30+
# Model creation doesn't return error
31+
assert palm2_text_generator_32k_model is not None
32+
assert palm2_text_generator_32k_model._bqml_model is not None
33+
34+
2935
@pytest.mark.flaky(retries=2, delay=120)
3036
def test_create_text_generator_model_default_session(bq_connection, llm_text_pandas_df):
3137
import bigframes.pandas as bpd
@@ -48,6 +54,30 @@ def test_create_text_generator_model_default_session(bq_connection, llm_text_pan
4854
assert all(series.str.len() > 20)
4955

5056

57+
@pytest.mark.flaky(retries=2, delay=120)
58+
def test_create_text_generator_32k_model_default_session(
59+
bq_connection, llm_text_pandas_df
60+
):
61+
import bigframes.pandas as bpd
62+
63+
bpd.close_session()
64+
bpd.options.bigquery.bq_connection = bq_connection
65+
bpd.options.bigquery.location = "us"
66+
67+
model = llm.PaLM2TextGenerator(model_name="text-bison-32k")
68+
assert model is not None
69+
assert model._bqml_model is not None
70+
assert model.connection_name.casefold() == "bigframes-dev.us.bigframes-rf-conn"
71+
72+
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
73+
74+
df = model.predict(llm_text_df).to_pandas()
75+
TestCase().assertSequenceEqual(df.shape, (3, 1))
76+
assert "ml_generate_text_llm_result" in df.columns
77+
series = df["ml_generate_text_llm_result"]
78+
assert all(series.str.len() > 20)
79+
80+
5181
@pytest.mark.flaky(retries=2, delay=120)
5282
def test_create_text_generator_model_default_connection(llm_text_pandas_df):
5383
from bigframes import _config
@@ -127,6 +157,14 @@ def test_create_embedding_generator_model(palm2_embedding_generator_model):
127157
assert palm2_embedding_generator_model._bqml_model is not None
128158

129159

160+
def test_create_embedding_generator_multilingual_model(
161+
palm2_embedding_generator_multilingual_model,
162+
):
163+
# Model creation doesn't return error
164+
assert palm2_embedding_generator_multilingual_model is not None
165+
assert palm2_embedding_generator_multilingual_model._bqml_model is not None
166+
167+
130168
def test_create_text_embedding_generator_model_defaults(bq_connection):
131169
import bigframes.pandas as bpd
132170

@@ -139,6 +177,20 @@ def test_create_text_embedding_generator_model_defaults(bq_connection):
139177
assert model._bqml_model is not None
140178

141179

180+
def test_create_text_embedding_generator_multilingual_model_defaults(bq_connection):
181+
import bigframes.pandas as bpd
182+
183+
bpd.close_session()
184+
bpd.options.bigquery.bq_connection = bq_connection
185+
bpd.options.bigquery.location = "us"
186+
187+
model = llm.PaLM2TextEmbeddingGenerator(
188+
model_name="textembedding-gecko-multilingual"
189+
)
190+
assert model is not None
191+
assert model._bqml_model is not None
192+
193+
142194
@pytest.mark.flaky(retries=2, delay=120)
143195
def test_embedding_generator_predict_success(
144196
palm2_embedding_generator_model, llm_text_df
@@ -152,6 +204,19 @@ def test_embedding_generator_predict_success(
152204
assert value.size == 768
153205

154206

207+
@pytest.mark.flaky(retries=2, delay=120)
208+
def test_embedding_generator_multilingual_predict_success(
209+
palm2_embedding_generator_multilingual_model, llm_text_df
210+
):
211+
df = palm2_embedding_generator_multilingual_model.predict(llm_text_df).to_pandas()
212+
TestCase().assertSequenceEqual(df.shape, (3, 1))
213+
assert "text_embedding" in df.columns
214+
series = df["text_embedding"]
215+
value = series[0]
216+
assert isinstance(value, np.ndarray)
217+
assert value.size == 768
218+
219+
155220
@pytest.mark.flaky(retries=2, delay=120)
156221
def test_embedding_generator_predict_series_success(
157222
palm2_embedding_generator_model, llm_text_df

0 commit comments

Comments
 (0)