23
23
from bigframes .ml import base , core , globals , utils
24
24
import bigframes .pandas as bpd
25
25
import third_party .bigframes_vendored .sklearn .preprocessing ._data
26
+ import third_party .bigframes_vendored .sklearn .preprocessing ._discretization
26
27
import third_party .bigframes_vendored .sklearn .preprocessing ._encoder
27
28
import third_party .bigframes_vendored .sklearn .preprocessing ._label
28
29
@@ -44,12 +45,15 @@ def __init__(self):
44
45
def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is exactly a StandardScaler with an equal BQML model."""
    # Exact type identity is required, so instances of subclasses never compare equal.
    if type(other) is not StandardScaler:
        return False
    return self._bqml_model == other._bqml_model
46
47
47
- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
48
+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
48
49
"""Compile this transformer to a list of SQL expressions that can be included in
49
50
a BQML TRANSFORM clause
50
51
51
52
Args:
52
- columns: a list of column names to transform
53
+ columns:
54
+ a list of column names to transform.
55
+ X (default None):
56
+ Ignored.
53
57
54
58
Returns: a list of tuples of (sql_expression, output_name)"""
55
59
return [
@@ -124,12 +128,15 @@ def __init__(self):
124
128
def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is exactly a MaxAbsScaler with an equal BQML model."""
    # Exact type identity is required, so instances of subclasses never compare equal.
    if type(other) is not MaxAbsScaler:
        return False
    return self._bqml_model == other._bqml_model
126
130
127
- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
131
+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
128
132
"""Compile this transformer to a list of SQL expressions that can be included in
129
133
a BQML TRANSFORM clause
130
134
131
135
Args:
132
- columns: a list of column names to transform
136
+ columns:
137
+ a list of column names to transform.
138
+ X (default None):
139
+ Ignored.
133
140
134
141
Returns: a list of tuples of (sql_expression, output_name)"""
135
142
return [
@@ -204,12 +211,15 @@ def __init__(self):
204
211
def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is exactly a MinMaxScaler with an equal BQML model."""
    # Exact type identity is required, so instances of subclasses never compare equal.
    if type(other) is not MinMaxScaler:
        return False
    return self._bqml_model == other._bqml_model
206
213
207
- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
214
+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
208
215
"""Compile this transformer to a list of SQL expressions that can be included in
209
216
a BQML TRANSFORM clause
210
217
211
218
Args:
212
- columns: a list of column names to transform
219
+ columns:
220
+ a list of column names to transform.
221
+ X (default None):
222
+ Ignored.
213
223
214
224
Returns: a list of tuples of (sql_expression, output_name)"""
215
225
return [
@@ -267,6 +277,124 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
267
277
)
268
278
269
279
280
class KBinsDiscretizer(
    base.Transformer,
    third_party.bigframes_vendored.sklearn.preprocessing._discretization.KBinsDiscretizer,
):
    __doc__ = (
        third_party.bigframes_vendored.sklearn.preprocessing._discretization.KBinsDiscretizer.__doc__
    )

    def __init__(
        self,
        n_bins: int = 5,
        strategy: Literal["uniform", "quantile"] = "quantile",
    ):
        """Validate and store the binning configuration.

        Args:
            n_bins:
                Number of bins to produce; must be at least 2.
            strategy:
                Binning strategy. Only "uniform" is accepted.

        Raises:
            NotImplementedError: if ``strategy`` is not "uniform".
            ValueError: if ``n_bins`` is smaller than 2.
        """
        # NOTE(review): the default strategy ("quantile") is rejected right here, so
        # callers currently have to pass strategy="uniform" explicitly.
        if strategy != "uniform":
            raise NotImplementedError(
                f"Only strategy = 'uniform' is supported now, input is {strategy}."
            )
        if n_bins < 2:
            raise ValueError(
                f"n_bins has to be larger than or equal to 2, input is {n_bins}."
            )
        self.n_bins = n_bins
        self.strategy = strategy
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()
        self._base_sql_generator = globals.base_sql_generator()
        # Output column names of the TRANSFORM clause; filled in by fit().
        self._output_names: List[str] = []

    # TODO(garrettwu): implement __hash__
    def __eq__(self, other: Any) -> bool:
        """Return whether ``other`` is exactly a KBinsDiscretizer with the same
        configuration and an equal BQML model."""
        return (
            type(other) is KBinsDiscretizer
            and self.n_bins == other.n_bins
            and self.strategy == other.strategy
            and self._bqml_model == other._bqml_model
        )

    def _compile_to_sql(
        self,
        columns: List[str],
        X: bpd.DataFrame,
    ) -> List[Tuple[str, str]]:
        """Compile this transformer to a list of SQL expressions that can be included in
        a BQML TRANSFORM clause

        Args:
            columns:
                a list of column names to transform
            X:
                The Dataframe with training data. The min and max of every
                column are read from it to derive the bin edges.

        Raises:
            NotImplementedError: if ``self.strategy`` is not "uniform".

        Returns: a list of tuples of (sql_expression, output_name)"""
        if self.strategy != "uniform":
            # __init__ already rejects this; guard against later attribute mutation
            # instead of failing with an opaque KeyError further down.
            raise NotImplementedError(
                f"Only strategy = 'uniform' is supported now, input is {self.strategy}."
            )

        compiled: List[Tuple[str, str]] = []
        for column in columns:
            min_value = X[column].min()
            max_value = X[column].max()
            bin_size = (max_value - min_value) / self.n_bins
            # Only the n_bins - 1 interior edges are split points. Starting at
            # i = 1 keeps min_value itself out of the list; including it would
            # create a bucket below the minimum that the training data can never
            # reach and leave only n_bins - 1 usable bins.
            split_points = [min_value + i * bin_size for i in range(1, self.n_bins)]
            output_name = f"kbinsdiscretizer_{column}"
            compiled.append(
                (
                    self._base_sql_generator.ml_bucketize(
                        column, split_points, output_name
                    ),
                    output_name,
                )
            )
        return compiled

    @classmethod
    def _parse_from_sql(cls, sql: str) -> tuple[KBinsDiscretizer, str]:
        """Parse SQL to tuple(KBinsDiscretizer, column_label).

        Args:
            sql: SQL string of format "ML.BUCKETIZE({col_label}, array_split_points, FALSE) OVER()"

        Returns:
            tuple(KBinsDiscretizer, column_label)"""
        # Text between the first "(" and the first ")": the ML.BUCKETIZE arguments.
        arguments = sql[sql.find("(") + 1 : sql.find(")")]
        array_split_points = arguments[
            arguments.find("[") + 1 : arguments.find("]")
        ]
        col_label = arguments[: arguments.find(",")]
        # n_bins - 1 split points are separated by n_bins - 2 commas.
        n_bins = array_split_points.count(",") + 2
        return cls(n_bins, "uniform"), col_label

    def fit(
        self,
        X: Union[bpd.DataFrame, bpd.Series],
        y=None,  # ignored
    ) -> KBinsDiscretizer:
        """Create a transform-only BQML model that bucketizes every column of ``X``.

        Args:
            X:
                Training data; a Series is converted to a DataFrame first.
            y:
                Ignored.

        Returns:
            This fitted transformer."""
        (X,) = utils.convert_to_dataframe(X)

        compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
        transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]

        self._bqml_model = self._bqml_model_factory.create_model(
            X,
            options={"model_type": "transform_only"},
            transforms=transform_sqls,
        )

        # The schema of TRANSFORM output is not available in the model API, so save it during fitting
        self._output_names = [name for _, name in compiled_transforms]
        return self

    def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Apply the fitted model to ``X`` and return only the bucketized columns.

        Raises:
            RuntimeError: if called before ``fit``."""
        if not self._bqml_model:
            raise RuntimeError("Must be fitted before transform")

        (X,) = utils.convert_to_dataframe(X)

        df = self._bqml_model.transform(X)
        return typing.cast(
            bpd.DataFrame,
            df[self._output_names],
        )
396
+
397
+
270
398
class OneHotEncoder (
271
399
base .Transformer ,
272
400
third_party .bigframes_vendored .sklearn .preprocessing ._encoder .OneHotEncoder ,
@@ -308,13 +436,15 @@ def __eq__(self, other: Any) -> bool:
308
436
and self .max_categories == other .max_categories
309
437
)
310
438
311
- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
439
+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
312
440
"""Compile this transformer to a list of SQL expressions that can be included in
313
441
a BQML TRANSFORM clause
314
442
315
443
Args:
316
444
columns:
317
- a list of column names to transform
445
+ a list of column names to transform.
446
+ X (default None):
447
+ Ignored.
318
448
319
449
Returns: a list of tuples of (sql_expression, output_name)"""
320
450
@@ -432,13 +562,15 @@ def __eq__(self, other: Any) -> bool:
432
562
and self .max_categories == other .max_categories
433
563
)
434
564
435
- def _compile_to_sql (self , columns : List [str ]) -> List [Tuple [str , str ]]:
565
+ def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
436
566
"""Compile this transformer to a list of SQL expressions that can be included in
437
567
a BQML TRANSFORM clause
438
568
439
569
Args:
440
570
columns:
441
- a list of column names to transform
571
+ a list of column names to transform.
572
+ X (default None):
573
+ Ignored.
442
574
443
575
Returns: a list of tuples of (sql_expression, output_name)"""
444
576
0 commit comments