@@ -128,14 +128,12 @@ def model(self) -> bigquery.Model:
128
128
return self ._model
129
129
130
130
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Run the SQL generator's ``ml_predict`` builder over ``input_data``.

    Args:
        input_data: DataFrame handed to ``self._apply_sql`` as the source rows.

    Returns:
        Whatever ``self._apply_sql`` produces for the ``ml_predict`` builder.
    """
    predict_sql_builder = self._model_manipulation_sql_generator.ml_predict
    return self._apply_sql(input_data, predict_sql_builder)
136
135
137
136
def transform (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
138
- # TODO: validate input data schema
139
137
return self ._apply_sql (
140
138
input_data ,
141
139
self ._model_manipulation_sql_generator .ml_transform ,
@@ -146,7 +144,6 @@ def generate_text(
146
144
input_data : bpd .DataFrame ,
147
145
options : Mapping [str , int | float ],
148
146
) -> bpd .DataFrame :
149
- # TODO: validate input data schema
150
147
return self ._apply_sql (
151
148
input_data ,
152
149
lambda source_df : self ._model_manipulation_sql_generator .ml_generate_text (
@@ -160,7 +157,6 @@ def generate_text_embedding(
160
157
input_data : bpd .DataFrame ,
161
158
options : Mapping [str , int | float ],
162
159
) -> bpd .DataFrame :
163
- # TODO: validate input data schema
164
160
return self ._apply_sql (
165
161
input_data ,
166
162
lambda source_df : self ._model_manipulation_sql_generator .ml_generate_text_embedding (
@@ -169,12 +165,24 @@ def generate_text_embedding(
169
165
),
170
166
)
171
167
168
def detect_anomalies(
    self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
    """Apply the SQL generator's ``ml_detect_anomalies`` to ``input_data``.

    Args:
        input_data: DataFrame handed to ``self._apply_sql`` as the source rows.
        options: Option names mapped to numeric values; forwarded unchanged
            to the generator as ``struct_options``.

    Returns:
        Whatever ``self._apply_sql`` produces for the generated SQL.

    Raises:
        ValueError: If ``self._model.model_type`` is not one of ``"PCA"``,
            ``"KMEANS"`` or ``"ARIMA_PLUS"``.
    """
    supported_model_types = ("PCA", "KMEANS", "ARIMA_PLUS")
    model_type = self._model.model_type
    # Raise explicitly rather than ``assert``: assertions are removed under
    # ``python -O``, which would let an unsupported model type through.
    if model_type not in supported_model_types:
        raise ValueError(
            "detect_anomalies requires a model of type "
            f"{', '.join(supported_model_types)}; got {model_type!r}."
        )

    return self._apply_sql(
        input_data,
        lambda source_df: self._model_manipulation_sql_generator.ml_detect_anomalies(
            source_df=source_df,
            struct_options=options,
        ),
    )
180
+
172
181
def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
    """Build the forecast SQL and read its result through the session.

    Args:
        options: Forwarded as ``struct_options`` to the SQL generator's
            ``ml_forecast``.

    Returns:
        The session's ``read_gbq`` result, read with ``"forecast_timestamp"``
        as the index column and then passed through ``reset_index()``.
    """
    forecast_sql = self._model_manipulation_sql_generator.ml_forecast(
        struct_options=options
    )
    indexed_result = self._session.read_gbq(
        forecast_sql, index_col="forecast_timestamp"
    )
    return indexed_result.reset_index()
175
184
176
185
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
    """Build the evaluation SQL and read its result through the session.

    Args:
        input_data: Optional DataFrame passed straight to the SQL generator's
            ``ml_evaluate``; defaults to ``None``.

    Returns:
        Whatever the session's ``read_gbq`` returns for the generated SQL.
    """
    sql_generator = self._model_manipulation_sql_generator
    evaluation_sql = sql_generator.ml_evaluate(input_data)
    return self._session.read_gbq(evaluation_sql)
0 commit comments