Skip to content

Commit 81125f9

Browse files
authored
feat: send warnings on LLM prediction partial failures (#216)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent a729831 commit 81125f9

File tree

1 file changed

+49
-29
lines changed

1 file changed

+49
-29
lines changed

bigframes/ml/llm.py

+49-29
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,30 @@
1717
from __future__ import annotations
1818

1919
from typing import cast, Literal, Optional, Union
20+
import warnings
2021

2122
import bigframes
2223
from bigframes import clients, constants
2324
from bigframes.core import blocks
2425
from bigframes.ml import base, core, globals, utils
2526
import bigframes.pandas as bpd
2627

27-
_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT = "text-bison"
28-
_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT = "text-bison-32k"
29-
_TEXT_GENERATE_RESULT_COLUMN = "ml_generate_text_llm_result"
28+
_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison"
29+
_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k"
30+
_TEXT_GENERATOR_ENDPOINTS = (
31+
_TEXT_GENERATOR_BISON_ENDPOINT,
32+
_TEXT_GENERATOR_BISON_32K_ENDPOINT,
33+
)
3034

31-
_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT = "textembedding-gecko"
32-
_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT = (
33-
"textembedding-gecko-multilingual"
35+
_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko"
36+
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual"
37+
_EMBEDDING_GENERATOR_ENDPOINTS = (
38+
_EMBEDDING_GENERATOR_GECKO_ENDPOINT,
39+
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
3440
)
35-
_EMBED_TEXT_RESULT_COLUMN = "text_embedding"
41+
42+
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
43+
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
3644

3745

3846
class PaLM2TextGenerator(base.Predictor):
@@ -90,18 +98,16 @@ def _create_bqml_model(self):
9098
connection_id=connection_name_parts[2],
9199
iam_role="aiplatform.user",
92100
)
93-
if self.model_name == _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT:
94-
options = {
95-
"endpoint": _REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT,
96-
}
97-
elif self.model_name == _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT:
98-
options = {
99-
"endpoint": _REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT,
100-
}
101-
else:
101+
102+
if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
102103
raise ValueError(
103-
f"Model name {self.model_name} is not supported. We only support {_REMOTE_TEXT_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_TEXT_GENERATOR_32K_MODEL_ENDPOINT}."
104+
f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_GENERATOR_ENDPOINTS)}."
104105
)
106+
107+
options = {
108+
"endpoint": self.model_name,
109+
}
110+
105111
return self._bqml_model_factory.create_remote_model(
106112
session=self.session, connection_name=self.connection_name, options=options
107113
)
@@ -182,7 +188,16 @@ def predict(
182188
"top_p": top_p,
183189
"flatten_json_output": True,
184190
}
185-
return self._bqml_model.generate_text(X, options)
191+
192+
df = self._bqml_model.generate_text(X, options)
193+
194+
if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
195+
warnings.warn(
196+
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
197+
RuntimeWarning,
198+
)
199+
200+
return df
186201

187202

188203
class PaLM2TextEmbeddingGenerator(base.Predictor):
@@ -241,19 +256,15 @@ def _create_bqml_model(self):
241256
connection_id=connection_name_parts[2],
242257
iam_role="aiplatform.user",
243258
)
244-
if self.model_name == "textembedding-gecko":
245-
options = {
246-
"endpoint": _REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT,
247-
}
248-
elif self.model_name == _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT:
249-
options = {
250-
"endpoint": _REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT,
251-
}
252-
else:
259+
260+
if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
253261
raise ValueError(
254-
f"Model name {self.model_name} is not supported. We only support {_REMOTE_EMBEDDING_GENERATOR_MODEL_ENDPOINT} and {_REMOTE_EMBEDDING_GENERATOR_MUlTILINGUAL_MODEL_ENDPOINT}."
262+
f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}."
255263
)
256264

265+
options = {
266+
"endpoint": self.model_name,
267+
}
257268
return self._bqml_model_factory.create_remote_model(
258269
session=self.session, connection_name=self.connection_name, options=options
259270
)
@@ -284,4 +295,13 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
284295
options = {
285296
"flatten_json_output": True,
286297
}
287-
return self._bqml_model.generate_text_embedding(X, options)
298+
299+
df = self._bqml_model.generate_text_embedding(X, options)
300+
301+
if (df[_ML_EMBED_TEXT_STATUS] != "").any():
302+
warnings.warn(
303+
f"Some predictions failed. Check column {_ML_EMBED_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
304+
RuntimeWarning,
305+
)
306+
307+
return df

0 commit comments

Comments (0)