Skip to content

Commit dafbc1b

Browse files
authored
feat: add to_gbq() method for LLM models (#299)
1 parent 2e1a403 commit dafbc1b

File tree

3 files changed

+145
-2
lines changed

3 files changed

+145
-2
lines changed

bigframes/ml/llm.py

+74
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
from typing import cast, Literal, Optional, Union
2020
import warnings
2121

22+
from google.cloud import bigquery
23+
2224
import bigframes
2325
from bigframes import clients, constants
2426
from bigframes.core import blocks, log_adapter
@@ -113,6 +115,26 @@ def _create_bqml_model(self):
113115
session=self.session, connection_name=self.connection_name, options=options
114116
)
115117

118+
@classmethod
def _from_bq(
    cls, session: bigframes.Session, model: bigquery.Model
) -> PaLM2TextGenerator:
    """Wrap an existing BigQuery remote model as a PaLM2TextGenerator.

    The generator's ``model_name`` is the last "/"-separated segment of the
    remote endpoint, and its ``connection_name`` is the connection recorded
    in the model's remote model info.

    Args:
        session (bigframes.Session):
            session the returned generator is bound to.
        model (bigquery.Model):
            the BigQuery model to wrap. It must report the unspecified
            model type and carry ``remoteModelInfo`` with both ``endpoint``
            and ``connection`` entries.

    Returns:
        PaLM2TextGenerator: generator whose ``_bqml_model`` wraps ``model``.
    """
    # Sanity checks on the shape of a remote model's metadata.
    assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
    assert "remoteModelInfo" in model._properties
    remote_info = model._properties["remoteModelInfo"]
    assert "endpoint" in remote_info
    assert "connection" in remote_info

    # Only the trailing path segment of the endpoint names the model.
    endpoint_name = remote_info["endpoint"].rsplit("/", 1)[-1]

    generator = cls(
        session=session,
        model_name=endpoint_name,
        connection_name=remote_info["connection"],
    )
    # Point the wrapper at the BigQuery model that was passed in.
    generator._bqml_model = core.BqmlModel(session, model)
    return generator
137+
116138
def predict(
117139
self,
118140
X: Union[bpd.DataFrame, bpd.Series],
@@ -200,6 +222,21 @@ def predict(
200222

201223
return df
202224

225+
def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
    """Save the model to BigQuery.

    The underlying BQML model is copied to ``model_name`` and the copy is
    then read back through its session.

    Args:
        model_name (str):
            the name of the model.
        replace (bool, default False):
            whether to replace if the model already exists. Default to False.

    Returns:
        PaLM2TextGenerator: saved model."""

    saved_copy = self._bqml_model.copy(model_name, replace)
    # Reload via the session so the caller gets a wrapper of the stored model.
    return saved_copy.session.read_gbq_model(model_name)
239+
203240

204241
@log_adapter.class_logger
205242
class PaLM2TextEmbeddingGenerator(base.Predictor):
@@ -271,6 +308,26 @@ def _create_bqml_model(self):
271308
session=self.session, connection_name=self.connection_name, options=options
272309
)
273310

311+
@classmethod
def _from_bq(
    cls, session: bigframes.Session, model: bigquery.Model
) -> PaLM2TextEmbeddingGenerator:
    """Wrap an existing BigQuery remote model as a PaLM2TextEmbeddingGenerator.

    The generator's ``model_name`` is the last "/"-separated segment of the
    remote endpoint, and its ``connection_name`` is the connection recorded
    in the model's remote model info.

    Args:
        session (bigframes.Session):
            session the returned generator is bound to.
        model (bigquery.Model):
            the BigQuery model to wrap. It must report the unspecified
            model type and carry ``remoteModelInfo`` with both ``endpoint``
            and ``connection`` entries.

    Returns:
        PaLM2TextEmbeddingGenerator: generator whose ``_bqml_model`` wraps
        ``model``.
    """
    # Sanity checks on the shape of a remote model's metadata.
    assert model.model_type == "MODEL_TYPE_UNSPECIFIED"
    assert "remoteModelInfo" in model._properties
    remote_info = model._properties["remoteModelInfo"]
    assert "endpoint" in remote_info
    assert "connection" in remote_info

    # Only the trailing path segment of the endpoint names the model.
    endpoint_name = remote_info["endpoint"].rsplit("/", 1)[-1]

    generator = cls(
        session=session,
        model_name=endpoint_name,
        connection_name=remote_info["connection"],
    )
    # Point the wrapper at the BigQuery model that was passed in.
    generator._bqml_model = core.BqmlModel(session, model)
    return generator
330+
274331
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
275332
"""Predict the result from input DataFrame.
276333
@@ -307,3 +364,20 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
307364
)
308365

309366
return df
367+
368+
def to_gbq(
    self, model_name: str, replace: bool = False
) -> PaLM2TextEmbeddingGenerator:
    """Save the model to BigQuery.

    The underlying BQML model is copied to ``model_name`` and the copy is
    then read back through its session.

    Args:
        model_name (str):
            the name of the model.
        replace (bool, default False):
            whether to replace if the model already exists. Default to False.

    Returns:
        PaLM2TextEmbeddingGenerator: saved model."""

    saved_copy = self._bqml_model.copy(model_name, replace)
    # Reload via the session so the caller gets a wrapper of the stored model.
    return saved_copy.session.read_gbq_model(model_name)

bigframes/ml/loader.py

+23
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
forecasting,
2929
imported,
3030
linear_model,
31+
llm,
3132
pipeline,
3233
)
3334

@@ -47,6 +48,15 @@
4748
}
4849
)
4950

51+
# Read-only lookup from a remote model's endpoint name to the LLM wrapper
# class used to load it. Endpoints are grouped by the class they map to.
_BQML_ENDPOINT_TYPE_MAPPING = MappingProxyType(
    {
        **dict.fromkeys(
            (
                llm._TEXT_GENERATOR_BISON_ENDPOINT,
                llm._TEXT_GENERATOR_BISON_32K_ENDPOINT,
            ),
            llm.PaLM2TextGenerator,
        ),
        **dict.fromkeys(
            (
                llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT,
                llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
            ),
            llm.PaLM2TextEmbeddingGenerator,
        ),
    }
)
59+
5060

5161
def from_bq(
5262
session: bigframes.Session, bq_model: bigquery.Model
@@ -62,6 +72,8 @@ def from_bq(
6272
ensemble.RandomForestClassifier,
6373
imported.TensorFlowModel,
6474
imported.ONNXModel,
75+
llm.PaLM2TextGenerator,
76+
llm.PaLM2TextEmbeddingGenerator,
6577
pipeline.Pipeline,
6678
]:
6779
"""Load a BQML model to BigQuery DataFrames ML.
@@ -84,6 +96,17 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
8496
return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore
8597
session=session, model=bq_model
8698
)
99+
if (
100+
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
101+
and "remoteModelInfo" in bq_model._properties
102+
and "endpoint" in bq_model._properties["remoteModelInfo"]
103+
):
104+
# Parse the remote model endpoint
105+
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
106+
endpoint_model = bqml_endpoint.split("/")[-1]
107+
return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore
108+
session=session, model=bq_model
109+
)
87110

88111
raise NotImplementedError(
89112
f"Model type {bq_model.model_type} is not yet supported by BigQuery DataFrames. {constants.FEEDBACK_LINK}"

tests/system/small/ml/test_llm.py

+48-2
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,37 @@
1717
from bigframes.ml import llm
1818

1919

20-
def test_create_text_generator_model(palm2_text_generator_model):
20+
def test_create_text_generator_model(
    palm2_text_generator_model, dataset_id, bq_connection
):
    """Check the text generator fixture and its save/reload round trip."""
    # The fixture must provide a model with a BQML model attached.
    assert palm2_text_generator_model is not None
    assert palm2_text_generator_model._bqml_model is not None

    # Save to BigQuery and reload; the configuration must survive the trip.
    saved_name = f"{dataset_id}.temp_text_model"
    restored = palm2_text_generator_model.to_gbq(saved_name, replace=True)
    assert saved_name == restored._bqml_model.model_name
    assert restored.model_name == "text-bison"
    assert restored.connection_name == bq_connection
34+
35+
36+
def test_create_text_generator_32k_model(
    palm2_text_generator_32k_model, dataset_id, bq_connection
):
    """Check the 32k text generator fixture and its save/reload round trip."""
    # The fixture must provide a model with a BQML model attached.
    assert palm2_text_generator_32k_model is not None
    assert palm2_text_generator_32k_model._bqml_model is not None

    # Save to BigQuery and reload; the configuration must survive the trip.
    saved_name = f"{dataset_id}.temp_text_model"
    restored = palm2_text_generator_32k_model.to_gbq(saved_name, replace=True)
    assert saved_name == restored._bqml_model.model_name
    assert restored.model_name == "text-bison-32k"
    assert restored.connection_name == bq_connection
50+
2551

2652
@pytest.mark.flaky(retries=2, delay=120)
2753
def test_create_text_generator_model_default_session(
@@ -152,19 +178,39 @@ def test_text_generator_predict_with_params_success(
152178
assert all(series.str.len() > 20)
153179

154180

155-
def test_create_embedding_generator_model(palm2_embedding_generator_model):
181+
def test_create_embedding_generator_model(
    palm2_embedding_generator_model, dataset_id, bq_connection
):
    """Check the embedding generator fixture and its save/reload round trip."""
    # The fixture must provide a model with a BQML model attached.
    assert palm2_embedding_generator_model is not None
    assert palm2_embedding_generator_model._bqml_model is not None

    # Save to BigQuery and reload; the configuration must survive the trip.
    saved_name = f"{dataset_id}.temp_embedding_model"
    restored = palm2_embedding_generator_model.to_gbq(saved_name, replace=True)
    assert saved_name == restored._bqml_model.model_name
    assert restored.model_name == "textembedding-gecko"
    assert restored.connection_name == bq_connection
195+
160196

161197
def test_create_embedding_generator_multilingual_model(
    palm2_embedding_generator_multilingual_model,
    dataset_id,
    bq_connection,
):
    """Check the multilingual embedding fixture and its save/reload round trip."""
    # The fixture must provide a model with a BQML model attached.
    assert palm2_embedding_generator_multilingual_model is not None
    assert palm2_embedding_generator_multilingual_model._bqml_model is not None

    # Save to BigQuery and reload; the configuration must survive the trip.
    saved_name = f"{dataset_id}.temp_embedding_model"
    restored = palm2_embedding_generator_multilingual_model.to_gbq(
        saved_name, replace=True
    )
    assert saved_name == restored._bqml_model.model_name
    assert restored.model_name == "textembedding-gecko-multilingual"
    assert restored.connection_name == bq_connection
213+
168214

169215
def test_create_text_embedding_generator_model_defaults(bq_connection):
170216
import bigframes.pandas as bpd

0 commit comments

Comments
 (0)