@@ -50,7 +50,7 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
50
50
self .model_name
51
51
)
52
52
53
- def _predict_sql (
53
+ def _apply_ml_tvf (
54
54
self ,
55
55
input_data : bpd .DataFrame ,
56
56
apply_sql_tvf : Callable [[str ], str ],
@@ -114,13 +114,13 @@ def model(self) -> bigquery.Model:
114
114
return self ._model
115
115
116
116
def predict (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
117
- return self ._predict_sql (
117
+ return self ._apply_ml_tvf (
118
118
input_data ,
119
119
self ._model_manipulation_sql_generator .ml_predict ,
120
120
)
121
121
122
122
def transform (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
123
- return self ._predict_sql (
123
+ return self ._apply_ml_tvf (
124
124
input_data ,
125
125
self ._model_manipulation_sql_generator .ml_transform ,
126
126
)
@@ -130,7 +130,7 @@ def generate_text(
130
130
input_data : bpd .DataFrame ,
131
131
options : Mapping [str , int | float ],
132
132
) -> bpd .DataFrame :
133
- return self ._predict_sql (
133
+ return self ._apply_ml_tvf (
134
134
input_data ,
135
135
lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_text (
136
136
source_sql = source_sql ,
@@ -143,7 +143,7 @@ def generate_embedding(
143
143
input_data : bpd .DataFrame ,
144
144
options : Mapping [str , int | float ],
145
145
) -> bpd .DataFrame :
146
- return self ._predict_sql (
146
+ return self ._apply_ml_tvf (
147
147
input_data ,
148
148
lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_embedding (
149
149
source_sql = source_sql ,
@@ -156,7 +156,7 @@ def detect_anomalies(
156
156
) -> bpd .DataFrame :
157
157
assert self ._model .model_type in ("PCA" , "KMEANS" , "ARIMA_PLUS" )
158
158
159
- return self ._predict_sql (
159
+ return self ._apply_ml_tvf (
160
160
input_data ,
161
161
lambda source_sql : self ._model_manipulation_sql_generator .ml_detect_anomalies (
162
162
source_sql = source_sql ,
0 commit comments