|
17 | 17 | from __future__ import annotations
|
18 | 18 |
|
19 | 19 | from typing import cast, Literal, Optional, Union
|
| 20 | +import warnings |
20 | 21 |
|
21 | 22 | import bigframes
|
22 | 23 | from bigframes import clients, constants
|
23 | 24 | from bigframes.core import blocks
|
24 | 25 | from bigframes.ml import base, core, globals, utils
|
25 | 26 | import bigframes.pandas as bpd
|
26 | 27 |
|
27 |
| -_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT = "text-bison" |
28 |
| -_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT = "text-bison-32k" |
29 |
| -_TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result" |
| 28 | +_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison" |
| 29 | +_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k" |
| 30 | +_TEXT_GENERATOR_ENDPOINTS = ( |
| 31 | + _TEXT_GENERATOR_BISON_ENDPOINT, |
| 32 | + _TEXT_GENERATOR_BISON_32K_ENDPOINT, |
| 33 | +) |
30 | 34 |
|
31 |
| -_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT = "textembedding-gecko" |
32 |
| -_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT = ( |
33 |
| - "textembedding-gecko-multilingual" |
| 35 | +_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko" |
| 36 | +_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual" |
| 37 | +_EMBEDDING_GENERATOR_ENDPOINTS = ( |
| 38 | + _EMBEDDING_GENERATOR_GECKO_ENDPOINT, |
| 39 | + _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT, |
34 | 40 | )
|
35 |
| -_EMBED_TEXT_RESULT_COLUMN = "text_embedding" |
| 41 | + |
| 42 | +_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" |
| 43 | +_ML_EMBED_TEXT_STATUS = "ml_embed_text_status" |
36 | 44 |
|
37 | 45 |
|
38 | 46 | class PaLM2TextGenerator(base.Predictor):
|
@@ -90,18 +98,16 @@ def _create_bqml_model(self):
|
90 | 98 | connection_id=connection_name_parts[2],
|
91 | 99 | iam_role="aiplatform.user",
|
92 | 100 | )
|
93 |
| - if self.model_name == _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT: |
94 |
| - options = { |
95 |
| - "endpoint": _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT, |
96 |
| - } |
97 |
| - elif self.model_name == _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT: |
98 |
| - options = { |
99 |
| - "endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT, |
100 |
| - } |
101 |
| - else: |
| 101 | + |
| 102 | + if self.model_name not in _TEXT_GENERATOR_ENDPOINTS: |
102 | 103 | raise ValueError(
|
103 |
| - f"Model name {self.model_name} is not supported. We only support {_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT}." |
| 104 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}." |
104 | 105 | )
|
| 106 | + |
| 107 | + options = { |
| 108 | + "endpoint": self.model_name, |
| 109 | + } |
| 110 | + |
105 | 111 | return self._bqml_model_factory.create_remote_model(
|
106 | 112 | session=self.session, connection_name=self.connection_name, options=options
|
107 | 113 | )
|
@@ -182,7 +188,16 @@ def predict(
|
182 | 188 | "top_p": top_p,
|
183 | 189 | "flatten_json_output": True,
|
184 | 190 | }
|
185 |
| - return self._bqml_model.generate_text(X, options) |
| 191 | + |
| 192 | + df = self._bqml_model.generate_text(X, options) |
| 193 | + |
| 194 | + if (df[_ML_GENERATE_TEXT_STATUS] != "").any(): |
| 195 | + warnings.warn( |
| 196 | + f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", |
| 197 | + RuntimeWarning, |
| 198 | + ) |
| 199 | + |
| 200 | + return df |
186 | 201 |
|
187 | 202 |
|
188 | 203 | class PaLM2TextEmbeddingGenerator(base.Predictor):
|
@@ -241,19 +256,15 @@ def _create_bqml_model(self):
|
241 | 256 | connection_id=connection_name_parts[2],
|
242 | 257 | iam_role="aiplatform.user",
|
243 | 258 | )
|
244 |
| - if self.model_name == "textembedding-gecko": |
245 |
| - options = { |
246 |
| - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT, |
247 |
| - } |
248 |
| - elif self.model_name == _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT: |
249 |
| - options = { |
250 |
| - "endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT, |
251 |
| - } |
252 |
| - else: |
| 259 | + |
| 260 | + if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: |
253 | 261 | raise ValueError(
|
254 |
| - f"Model name {self.model_name} is not supported. We only support {_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT}." |
| 262 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}." |
255 | 263 | )
|
256 | 264 |
|
| 265 | + options = { |
| 266 | + "endpoint": self.model_name, |
| 267 | + } |
257 | 268 | return self._bqml_model_factory.create_remote_model(
|
258 | 269 | session=self.session, connection_name=self.connection_name, options=options
|
259 | 270 | )
|
@@ -284,4 +295,13 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
|
284 | 295 | options = {
|
285 | 296 | "flatten_json_output": True,
|
286 | 297 | }
|
287 |
| - return self._bqml_model.generate_text_embedding(X, options) |
| 298 | + |
| 299 | + df = self._bqml_model.generate_text_embedding(X, options) |
| 300 | + |
| 301 | + if (df[_ML_EMBED_TEXT_STATUS] != "").any(): |
| 302 | + warnings.warn( |
| 303 | + f"Some predictions failed. Check column {_ML_EMBED_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", |
| 304 | + RuntimeWarning, |
| 305 | + ) |
| 306 | + |
| 307 | + return df |