18
18
from __future__ import annotations
19
19
20
20
import typing
21
- from typing import Any , cast , List , Literal , Optional , Tuple , Union
21
+ from typing import cast , Iterable , List , Literal , Optional , Tuple , Union
22
22
23
23
import bigframes_vendored .sklearn .preprocessing ._data
24
24
import bigframes_vendored .sklearn .preprocessing ._discretization
@@ -43,11 +43,10 @@ def __init__(self):
43
43
self ._bqml_model_factory = globals .bqml_model_factory ()
44
44
self ._base_sql_generator = globals .base_sql_generator ()
45
45
46
- # TODO(garrettwu): implement __hash__
47
- def __eq__ (self , other : Any ) -> bool :
48
- return type (other ) is StandardScaler and self ._bqml_model == other ._bqml_model
46
+ def _keys (self ):
47
+ return (self ._bqml_model ,)
49
48
50
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
49
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
51
50
"""Compile this transformer to a list of SQL expressions that can be included in
52
51
a BQML TRANSFORM clause
53
52
@@ -125,11 +124,10 @@ def __init__(self):
125
124
self ._bqml_model_factory = globals .bqml_model_factory ()
126
125
self ._base_sql_generator = globals .base_sql_generator ()
127
126
128
- # TODO(garrettwu): implement __hash__
129
- def __eq__ (self , other : Any ) -> bool :
130
- return type (other ) is MaxAbsScaler and self ._bqml_model == other ._bqml_model
127
+ def _keys (self ):
128
+ return (self ._bqml_model ,)
131
129
132
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
130
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
133
131
"""Compile this transformer to a list of SQL expressions that can be included in
134
132
a BQML TRANSFORM clause
135
133
@@ -207,11 +205,10 @@ def __init__(self):
207
205
self ._bqml_model_factory = globals .bqml_model_factory ()
208
206
self ._base_sql_generator = globals .base_sql_generator ()
209
207
210
- # TODO(garrettwu): implement __hash__
211
- def __eq__ (self , other : Any ) -> bool :
212
- return type (other ) is MinMaxScaler and self ._bqml_model == other ._bqml_model
208
+ def _keys (self ):
209
+ return (self ._bqml_model ,)
213
210
214
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
211
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
215
212
"""Compile this transformer to a list of SQL expressions that can be included in
216
213
a BQML TRANSFORM clause
217
214
@@ -301,18 +298,12 @@ def __init__(
301
298
self ._bqml_model_factory = globals .bqml_model_factory ()
302
299
self ._base_sql_generator = globals .base_sql_generator ()
303
300
304
- # TODO(garrettwu): implement __hash__
305
- def __eq__ (self , other : Any ) -> bool :
306
- return (
307
- type (other ) is KBinsDiscretizer
308
- and self .n_bins == other .n_bins
309
- and self .strategy == other .strategy
310
- and self ._bqml_model == other ._bqml_model
311
- )
301
+ def _keys (self ):
302
+ return (self ._bqml_model , self .n_bins , self .strategy )
312
303
313
304
def _compile_to_sql (
314
305
self ,
315
- columns : List [str ],
306
+ columns : Iterable [str ],
316
307
X : bpd .DataFrame ,
317
308
) -> List [Tuple [str , str ]]:
318
309
"""Compile this transformer to a list of SQL expressions that can be included in
@@ -446,17 +437,10 @@ def __init__(
446
437
self ._bqml_model_factory = globals .bqml_model_factory ()
447
438
self ._base_sql_generator = globals .base_sql_generator ()
448
439
449
- # TODO(garrettwu): implement __hash__
450
- def __eq__ (self , other : Any ) -> bool :
451
- return (
452
- type (other ) is OneHotEncoder
453
- and self ._bqml_model == other ._bqml_model
454
- and self .drop == other .drop
455
- and self .min_frequency == other .min_frequency
456
- and self .max_categories == other .max_categories
457
- )
440
+ def _keys (self ):
441
+ return (self ._bqml_model , self .drop , self .min_frequency , self .max_categories )
458
442
459
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
443
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
460
444
"""Compile this transformer to a list of SQL expressions that can be included in
461
445
a BQML TRANSFORM clause
462
446
@@ -572,16 +556,10 @@ def __init__(
572
556
self ._bqml_model_factory = globals .bqml_model_factory ()
573
557
self ._base_sql_generator = globals .base_sql_generator ()
574
558
575
- # TODO(garrettwu): implement __hash__
576
- def __eq__ (self , other : Any ) -> bool :
577
- return (
578
- type (other ) is LabelEncoder
579
- and self ._bqml_model == other ._bqml_model
580
- and self .min_frequency == other .min_frequency
581
- and self .max_categories == other .max_categories
582
- )
559
+ def _keys (self ):
560
+ return (self ._bqml_model , self .min_frequency , self .max_categories )
583
561
584
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
562
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
585
563
"""Compile this transformer to a list of SQL expressions that can be included in
586
564
a BQML TRANSFORM clause
587
565
@@ -672,18 +650,17 @@ class PolynomialFeatures(
672
650
)
673
651
674
652
def __init__ (self , degree : int = 2 ):
653
+ if degree not in range (1 , 5 ):
654
+ raise ValueError (f"degree has to be [1, 4], input is { degree } ." )
675
655
self .degree = degree
676
656
self ._bqml_model : Optional [core .BqmlModel ] = None
677
657
self ._bqml_model_factory = globals .bqml_model_factory ()
678
658
self ._base_sql_generator = globals .base_sql_generator ()
679
659
680
- # TODO(garrettwu): implement __hash__
681
- def __eq__ (self , other : Any ) -> bool :
682
- return (
683
- type (other ) is PolynomialFeatures and self ._bqml_model == other ._bqml_model
684
- )
660
+ def _keys (self ):
661
+ return (self ._bqml_model , self .degree )
685
662
686
- def _compile_to_sql (self , columns : List [str ], X = None ) -> List [Tuple [str , str ]]:
663
+ def _compile_to_sql (self , columns : Iterable [str ], X = None ) -> List [Tuple [str , str ]]:
687
664
"""Compile this transformer to a list of SQL expressions that can be included in
688
665
a BQML TRANSFORM clause
689
666
@@ -705,17 +682,18 @@ def _compile_to_sql(self, columns: List[str], X=None) -> List[Tuple[str, str]]:
705
682
]
706
683
707
684
@classmethod
708
- def _parse_from_sql (cls , sql : str ) -> tuple [PolynomialFeatures , str ]:
709
- """Parse SQL to tuple(PolynomialFeatures, column_label ).
685
+ def _parse_from_sql (cls , sql : str ) -> tuple [PolynomialFeatures , tuple [ str , ...] ]:
686
+ """Parse SQL to tuple(PolynomialFeatures, column_labels ).
710
687
711
688
Args:
712
689
sql: SQL string of format "ML.POLYNOMIAL_EXPAND(STRUCT(col_label0, col_label1, ...), degree)"
713
690
714
691
Returns:
715
692
tuple(MaxAbsScaler, column_label)"""
716
- col_label = sql [sql .find ("STRUCT(" ) + 7 : sql .find (")" )]
693
+ col_labels = sql [sql .find ("STRUCT(" ) + 7 : sql .find (")" )].split ("," )
694
+ col_labels = [label .strip () for label in col_labels ]
717
695
degree = int (sql [sql .rfind ("," ) + 1 : sql .rfind (")" )])
718
- return cls (degree ), col_label
696
+ return cls (degree ), tuple ( col_labels )
719
697
720
698
def fit (
721
699
self ,
@@ -762,8 +740,6 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
762
740
df [self ._output_names ],
763
741
)
764
742
765
- # TODO(garrettwu): to_gbq()
766
-
767
743
768
744
PreprocessingType = Union [
769
745
OneHotEncoder ,
@@ -772,4 +748,5 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
772
748
MinMaxScaler ,
773
749
KBinsDiscretizer ,
774
750
LabelEncoder ,
751
+ PolynomialFeatures ,
775
752
]
0 commit comments