17
17
from __future__ import annotations
18
18
19
19
import datetime
20
- from typing import Callable , cast , Iterable , Literal , Mapping , Optional , Union
20
+ from typing import Callable , cast , Iterable , Mapping , Optional , Union
21
21
import uuid
22
22
23
23
from google .cloud import bigquery
@@ -35,11 +35,27 @@ def __init__(self, session: bigframes.Session):
35
35
self ._session = session
36
36
self ._base_sql_generator = ml_sql .BaseSqlGenerator ()
37
37
38
- def _apply_sql (
38
+
39
+ class BqmlModel (BaseBqml ):
40
+ """Represents an existing BQML model in BigQuery.
41
+
42
+ Wraps the BQML API and SQL interface to expose the functionality needed for
43
+ BigQuery DataFrames ML.
44
+ """
45
+
46
+ def __init__ (self , session : bigframes .Session , model : bigquery .Model ):
47
+ self ._session = session
48
+ self ._model = model
49
+ self ._model_manipulation_sql_generator = ml_sql .ModelManipulationSqlGenerator (
50
+ self .model_name
51
+ )
52
+
53
+ def _apply_ml_tvf (
39
54
self ,
40
55
input_data : bpd .DataFrame ,
41
- func : Callable [[bpd . DataFrame ], str ],
56
+ apply_sql_tvf : Callable [[str ], str ],
42
57
) -> bpd .DataFrame :
58
+ # Used for predict, transform, distance
43
59
"""Helper to wrap a dataframe in a SQL query, keeping the index intact.
44
60
45
61
Args:
@@ -50,67 +66,28 @@ def _apply_sql(
50
66
the dataframe to be wrapped
51
67
52
68
func (function):
53
- a function that will accept a SQL string and produce a new SQL
54
- string from which to construct the output dataframe. It must
55
- include the index columns of the input SQL .
69
+ Takes an input sql table value and applies a prediction tvf. The
70
+ resulting table value must include all input columns, with new
71
+ columns appended to the end .
56
72
"""
57
- _ , index_col_ids , index_labels = input_data ._to_sql_query (include_index = True )
58
-
59
- sql = func (input_data )
60
- df = self ._session .read_gbq (sql , index_col = index_col_ids )
61
- df .index .names = index_labels
62
-
63
- return df
64
-
65
- def distance (
66
- self ,
67
- x : bpd .DataFrame ,
68
- y : bpd .DataFrame ,
69
- type : Literal ["EUCLIDEAN" , "MANHATTAN" , "COSINE" ],
70
- name : str ,
71
- ) -> bpd .DataFrame :
72
- """Calculate ML.DISTANCE from DataFrame inputs.
73
-
74
- Args:
75
- x:
76
- input DataFrame
77
- y:
78
- input DataFrame
79
- type:
80
- Distance types, accept values are "EUCLIDEAN", "MANHATTAN", "COSINE".
81
- name:
82
- name of the output result column
83
- """
84
- assert len (x .columns ) == 1 and len (y .columns ) == 1
85
-
86
- input_data = x .join (y , how = "outer" ).cache ()
87
- x_column_id , y_column_id = x ._block .value_columns [0 ], y ._block .value_columns [0 ]
88
-
89
- return self ._apply_sql (
90
- input_data ,
91
- lambda source_df : self ._base_sql_generator .ml_distance (
92
- x_column_id ,
93
- y_column_id ,
94
- type = type ,
95
- source_df = source_df ,
96
- name = name ,
97
- ),
73
+ # TODO: Preserve ordering information?
74
+ input_sql , index_col_ids , index_labels = input_data ._to_sql_query (
75
+ include_index = True
98
76
)
99
77
100
-
101
- class BqmlModel (BaseBqml ):
102
- """Represents an existing BQML model in BigQuery.
103
-
104
- Wraps the BQML API and SQL interface to expose the functionality needed for
105
- BigQuery DataFrames ML.
106
- """
107
-
108
- def __init__ (self , session : bigframes .Session , model : bigquery .Model ):
109
- self ._session = session
110
- self ._model = model
111
- self ._model_manipulation_sql_generator = ml_sql .ModelManipulationSqlGenerator (
112
- self .model_name
78
+ result_sql = apply_sql_tvf (input_sql )
79
+ df = self ._session .read_gbq (result_sql , index_col = index_col_ids )
80
+ df .index .names = index_labels
81
+ # Restore column labels
82
+ df .rename (
83
+ columns = {
84
+ label : original_label
85
+ for label , original_label in zip (
86
+ df .columns .values , input_data .columns .values
87
+ )
88
+ }
113
89
)
90
+ return df
114
91
115
92
def _keys (self ):
116
93
return (self ._session , self ._model )
@@ -137,13 +114,13 @@ def model(self) -> bigquery.Model:
137
114
return self ._model
138
115
139
116
def predict (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
140
- return self ._apply_sql (
117
+ return self ._apply_ml_tvf (
141
118
input_data ,
142
119
self ._model_manipulation_sql_generator .ml_predict ,
143
120
)
144
121
145
122
def transform (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
146
- return self ._apply_sql (
123
+ return self ._apply_ml_tvf (
147
124
input_data ,
148
125
self ._model_manipulation_sql_generator .ml_transform ,
149
126
)
@@ -153,10 +130,10 @@ def generate_text(
153
130
input_data : bpd .DataFrame ,
154
131
options : Mapping [str , int | float ],
155
132
) -> bpd .DataFrame :
156
- return self ._apply_sql (
133
+ return self ._apply_ml_tvf (
157
134
input_data ,
158
- lambda source_df : self ._model_manipulation_sql_generator .ml_generate_text (
159
- source_df = source_df ,
135
+ lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_text (
136
+ source_sql = source_sql ,
160
137
struct_options = options ,
161
138
),
162
139
)
@@ -166,10 +143,10 @@ def generate_embedding(
166
143
input_data : bpd .DataFrame ,
167
144
options : Mapping [str , int | float ],
168
145
) -> bpd .DataFrame :
169
- return self ._apply_sql (
146
+ return self ._apply_ml_tvf (
170
147
input_data ,
171
- lambda source_df : self ._model_manipulation_sql_generator .ml_generate_embedding (
172
- source_df = source_df ,
148
+ lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_embedding (
149
+ source_sql = source_sql ,
173
150
struct_options = options ,
174
151
),
175
152
)
@@ -179,10 +156,10 @@ def detect_anomalies(
179
156
) -> bpd .DataFrame :
180
157
assert self ._model .model_type in ("PCA" , "KMEANS" , "ARIMA_PLUS" )
181
158
182
- return self ._apply_sql (
159
+ return self ._apply_ml_tvf (
183
160
input_data ,
184
- lambda source_df : self ._model_manipulation_sql_generator .ml_detect_anomalies (
185
- source_df = source_df ,
161
+ lambda source_sql : self ._model_manipulation_sql_generator .ml_detect_anomalies (
162
+ source_sql = source_sql ,
186
163
struct_options = options ,
187
164
),
188
165
)
@@ -192,7 +169,9 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
192
169
return self ._session .read_gbq (sql , index_col = "forecast_timestamp" ).reset_index ()
193
170
194
171
def evaluate (self , input_data : Optional [bpd .DataFrame ] = None ):
195
- sql = self ._model_manipulation_sql_generator .ml_evaluate (input_data )
172
+ sql = self ._model_manipulation_sql_generator .ml_evaluate (
173
+ input_data .sql if (input_data is not None ) else None
174
+ )
196
175
197
176
return self ._session .read_gbq (sql )
198
177
@@ -202,7 +181,7 @@ def llm_evaluate(
202
181
task_type : Optional [str ] = None ,
203
182
):
204
183
sql = self ._model_manipulation_sql_generator .ml_llm_evaluate (
205
- input_data , task_type
184
+ input_data . sql , task_type
206
185
)
207
186
208
187
return self ._session .read_gbq (sql )
@@ -336,7 +315,7 @@ def create_model(
336
315
model_ref = self ._create_model_ref (session ._anonymous_dataset )
337
316
338
317
sql = self ._model_creation_sql_generator .create_model (
339
- source_df = input_data ,
318
+ source_sql = input_data . sql ,
340
319
model_ref = model_ref ,
341
320
transforms = transforms ,
342
321
options = options ,
@@ -374,7 +353,7 @@ def create_llm_remote_model(
374
353
model_ref = self ._create_model_ref (session ._anonymous_dataset )
375
354
376
355
sql = self ._model_creation_sql_generator .create_llm_remote_model (
377
- source_df = input_data ,
356
+ source_sql = input_data . sql ,
378
357
model_ref = model_ref ,
379
358
options = options ,
380
359
connection_name = connection_name ,
@@ -407,7 +386,7 @@ def create_time_series_model(
407
386
model_ref = self ._create_model_ref (session ._anonymous_dataset )
408
387
409
388
sql = self ._model_creation_sql_generator .create_model (
410
- source_df = input_data ,
389
+ source_sql = input_data . sql ,
411
390
model_ref = model_ref ,
412
391
transforms = transforms ,
413
392
options = options ,
0 commit comments