Skip to content

Commit 3ffc1d2

Browse files
authored
feat: support the score method for PaLM2TextGenerator (#634)
* feat: support the score method for PaLM2TextGenerator * address comments * address additional comments * address minor comments
1 parent 3acc494 commit 3ffc1d2

File tree

7 files changed

+146
-12
lines changed

7 files changed

+146
-12
lines changed

bigframes/ml/core.py

+11
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,17 @@ def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
187187

188188
return self._session.read_gbq(sql)
189189

190+
def llm_evaluate(
191+
self,
192+
input_data: bpd.DataFrame,
193+
task_type: Optional[str] = None,
194+
):
195+
sql = self._model_manipulation_sql_generator.ml_llm_evaluate(
196+
input_data, task_type
197+
)
198+
199+
return self._session.read_gbq(sql)
200+
190201
def arima_evaluate(self, show_all_candidate_models: bool = False):
191202
sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
192203
show_all_candidate_models

bigframes/ml/llm.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def predict(
220220
221221
Args:
222222
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
223-
Input DataFrame or Series, which needs to contain a column with name "prompt". Only the column will be used as input.
223+
Input DataFrame or Series, which contains only one column of prompts.
224224
Prompts can include preamble, questions, suggestions, instructions, or examples.
225225
226226
temperature (float, default 0.0):
@@ -310,6 +310,63 @@ def predict(
310310

311311
return df
312312

313+
def score(
314+
self,
315+
X: Union[bpd.DataFrame, bpd.Series],
316+
y: Union[bpd.DataFrame, bpd.Series],
317+
task_type: Literal[
318+
"text_generation", "classification", "summarization", "question_answering"
319+
] = "text_generation",
320+
) -> bpd.DataFrame:
321+
"""Calculate evaluation metrics of the model.
322+
323+
.. note::
324+
325+
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
326+
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
327+
and might have limited support. For more information, see the launch stage descriptions
328+
(https://cloud.google.com/products#product-launch-stages).
329+
330+
.. note::
331+
332+
Output matches that of the BigQuery ML.EVALUATE function.
333+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate#remote-model-llm
334+
for the outputs relevant to this model type.
335+
336+
Args:
337+
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
338+
A BigQuery DataFrame as evaluation data, which contains only one column of input_text
339+
that contains the prompt text to use when evaluating the model.
340+
y (bigframes.dataframe.DataFrame or bigframes.series.Series):
341+
A BigQuery DataFrame as evaluation labels, which contains only one column of output_text
342+
that you would expect to be returned by the model.
343+
task_type (str):
344+
The type of the task for LLM model. Default to "text_generation".
345+
Possible values: "text_generation", "classification", "summarization", and "question_answering".
346+
347+
Returns:
348+
bigframes.dataframe.DataFrame: The DataFrame as evaluation result.
349+
"""
350+
if not self._bqml_model:
351+
raise RuntimeError("A model must be fitted before score")
352+
353+
X, y = utils.convert_to_dataframe(X, y)
354+
355+
if len(X.columns) != 1 or len(y.columns) != 1:
356+
raise ValueError(
357+
f"Only support one column as input for X and y. {constants.FEEDBACK_LINK}"
358+
)
359+
360+
# BQML identifies the columns by name
361+
X_col_label = cast(blocks.Label, X.columns[0])
362+
y_col_label = cast(blocks.Label, y.columns[0])
363+
X = X.rename(columns={X_col_label: "input_text"})
364+
y = y.rename(columns={y_col_label: "output_text"})
365+
366+
input_data = X.join(y, how="outer")
367+
368+
return self._bqml_model.llm_evaluate(input_data, task_type)
369+
313370
def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
314371
"""Save the model to BigQuery.
315372

bigframes/ml/sql.py

+10
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,16 @@ def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
318318
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
319319
({source_sql}))"""
320320

321+
# ML evaluation TVFs
322+
def ml_llm_evaluate(
323+
self, source_df: bpd.DataFrame, task_type: Optional[str] = None
324+
) -> str:
325+
"""Encode ML.EVALUATE for BQML"""
326+
# Note: don't need index as evaluate returns a new table
327+
source_sql, _, _ = source_df._to_sql_query(include_index=False)
328+
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`,
329+
({source_sql}), STRUCT("{task_type}" AS task_type))"""
330+
321331
# ML evaluation TVFs
322332
def ml_arima_evaluate(self, show_all_candidate_models: bool = False) -> str:
323333
"""Encode ML.ARMIA_EVALUATE for BQML"""

tests/system/load/test_llm.py

+49-7
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@
2222
def llm_fine_tune_df_default_index(
2323
session: bigframes.Session,
2424
) -> bigframes.dataframe.DataFrame:
25-
sql = """
26-
SELECT
27-
CONCAT("Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: ", text) as prompt,
28-
CAST(label AS STRING) as label
29-
FROM `llm_tuning.emotion_classification_train`
30-
"""
31-
return session.read_gbq(sql)
25+
training_table_name = "llm_tuning.emotion_classification_train"
26+
df = session.read_gbq(training_table_name)
27+
prefix = "Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: "
28+
df["prompt"] = prefix + df["text"]
29+
df["label"] = df["label"].astype("string")
30+
return df
3231

3332

3433
@pytest.fixture(scope="session")
@@ -69,3 +68,46 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_
6968
assert all(series.str.len() == 1)
7069

7170
# TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
71+
72+
73+
def test_llm_palm_score(llm_fine_tune_df_default_index):
74+
model = bigframes.ml.llm.PaLM2TextGenerator(model_name="text-bison")
75+
76+
# Check score to ensure the model was fitted
77+
score_result = model.score(
78+
X=llm_fine_tune_df_default_index[["prompt"]],
79+
y=llm_fine_tune_df_default_index[["label"]],
80+
).to_pandas()
81+
score_result_col = score_result.columns.to_list()
82+
expected_col = [
83+
"bleu4_score",
84+
"rouge-l_precision",
85+
"rouge-l_recall",
86+
"rouge-l_f1_score",
87+
"evaluation_status",
88+
]
89+
assert all(col in score_result_col for col in expected_col)
90+
91+
92+
def test_llm_palm_score_params(llm_fine_tune_df_default_index):
93+
model = bigframes.ml.llm.PaLM2TextGenerator(
94+
model_name="text-bison", max_iterations=1
95+
)
96+
97+
# Check score to ensure the model was fitted
98+
score_result = model.score(
99+
X=llm_fine_tune_df_default_index["prompt"],
100+
y=llm_fine_tune_df_default_index["label"],
101+
task_type="classification",
102+
).to_pandas()
103+
score_result_col = score_result.columns.to_list()
104+
expected_col = [
105+
"trial_id",
106+
"precision",
107+
"recall",
108+
"accuracy",
109+
"f1_score",
110+
"log_loss",
111+
"roc_auc",
112+
]
113+
assert all(col in score_result_col for col in expected_col)

tests/unit/ml/test_sql.py

+14
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,20 @@ def test_ml_predict_correct(
319319
)
320320

321321

322+
def test_ml_llm_evaluate_correct(
323+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
324+
mock_df: bpd.DataFrame,
325+
):
326+
sql = model_manipulation_sql_generator.ml_llm_evaluate(
327+
source_df=mock_df, task_type="CLASSIFICATION"
328+
)
329+
assert (
330+
sql
331+
== """SELECT * FROM ML.EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
332+
(input_X_sql), STRUCT("CLASSIFICATION" AS task_type))"""
333+
)
334+
335+
322336
def test_ml_evaluate_correct(
323337
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
324338
mock_df: bpd.DataFrame,

third_party/bigframes_vendored/sklearn/ensemble/_forest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class RandomForestRegressor(ForestRegressor):
9595
Number of parallel trees constructed during each iteration. Default to 100. Minimum value is 2.
9696
tree_method (Optional[str]):
9797
Specify which tree method to use. Default to "auto". If this parameter is set to
98-
default, XGBoost will choose the most conservative option available. Possible values: ""exact", "approx",
98+
default, XGBoost will choose the most conservative option available. Possible values: "exact", "approx",
9999
"hist".
100100
min_child_weight (Optional[float]):
101101
Minimum sum of instance weight(hessian) needed in a child. Default to 1.
@@ -160,7 +160,7 @@ class RandomForestClassifier(ForestClassifier):
160160
Number of parallel trees constructed during each iteration. Default to 100. Minimum value is 2.
161161
tree_method (Optional[str]):
162162
Specify which tree method to use. Default to "auto". If this parameter is set to
163-
default, XGBoost will choose the most conservative option available. Possible values: ""exact", "approx",
163+
default, XGBoost will choose the most conservative option available. Possible values: "exact", "approx",
164164
"hist".
165165
min_child_weight (Optional[float]):
166166
Minimum sum of instance weight(hessian) needed in a child. Default to 1.

third_party/bigframes_vendored/xgboost/sklearn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
6363
Type of normalization algorithm for DART booster. Possible values: "TREE", "FOREST". Default to "TREE".
6464
tree_method (Optional[str]):
6565
Specify which tree method to use. Default to "auto". If this parameter is set to
66-
default, XGBoost will choose the most conservative option available. Possible values: ""exact", "approx",
66+
default, XGBoost will choose the most conservative option available. Possible values: "exact", "approx",
6767
"hist".
6868
min_child_weight (Optional[float]):
6969
Minimum sum of instance weight(hessian) needed in a child. Default to 1.
@@ -110,7 +110,7 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
110110
Type of normalization algorithm for DART booster. Possible values: "TREE", "FOREST". Default to "TREE".
111111
tree_method (Optional[str]):
112112
Specify which tree method to use. Default to "auto". If this parameter is set to
113-
default, XGBoost will choose the most conservative option available. Possible values: ""exact", "approx",
113+
default, XGBoost will choose the most conservative option available. Possible values: "exact", "approx",
114114
"hist".
115115
min_child_weight (Optional[float]):
116116
Minimum sum of instance weight(hessian) needed in a child. Default to 1.

0 commit comments

Comments
 (0)