@@ -742,6 +742,109 @@ def test_read_gbq_function_enforces_explicit_types(
742
742
)
743
743
744
744
745
+ @pytest .mark .flaky (retries = 2 , delay = 120 )
746
+ def test_df_apply_axis_1 (session , scalars_dfs ):
747
+ columns = [
748
+ "bool_col" ,
749
+ "int64_col" ,
750
+ "int64_too" ,
751
+ "float64_col" ,
752
+ "string_col" ,
753
+ "bytes_col" ,
754
+ ]
755
+ scalars_df , scalars_pandas_df = scalars_dfs
756
+
757
+ def add_ints (row ):
758
+ return row ["int64_col" ] + row ["int64_too" ]
759
+
760
+ with pytest .warns (
761
+ bigframes .exceptions .PreviewWarning ,
762
+ match = "input_types=Series is in preview." ,
763
+ ):
764
+ add_ints_remote = session .remote_function (
765
+ bigframes .series .Series ,
766
+ int ,
767
+ )(add_ints )
768
+
769
+ with pytest .warns (
770
+ bigframes .exceptions .PreviewWarning , match = "axis=1 scenario is in preview."
771
+ ):
772
+ bf_result = scalars_df [columns ].apply (add_ints_remote , axis = 1 ).to_pandas ()
773
+
774
+ pd_result = scalars_pandas_df [columns ].apply (add_ints , axis = 1 )
775
+
776
+ # bf_result.dtype is 'Int64' while pd_result.dtype is 'object', ignore this
777
+ # mismatch by using check_dtype=False.
778
+ #
779
+ # bf_result.to_numpy() produces an array of numpy.float64's
780
+ # (in system_prerelease tests), while pd_result.to_numpy() produces an
781
+ # array of ints, ignore this mismatch by using check_exact=False.
782
+ pd .testing .assert_series_equal (
783
+ pd_result , bf_result , check_dtype = False , check_exact = False
784
+ )
785
+
786
+
787
+ @pytest .mark .flaky (retries = 2 , delay = 120 )
788
+ def test_df_apply_axis_1_ordering (session , scalars_dfs ):
789
+ columns = ["bool_col" , "int64_col" , "int64_too" , "float64_col" , "string_col" ]
790
+ ordering_columns = ["bool_col" , "int64_col" ]
791
+ scalars_df , scalars_pandas_df = scalars_dfs
792
+
793
+ def add_ints (row ):
794
+ return row ["int64_col" ] + row ["int64_too" ]
795
+
796
+ add_ints_remote = session .remote_function (bigframes .series .Series , int )(add_ints )
797
+
798
+ bf_result = (
799
+ scalars_df [columns ]
800
+ .sort_values (ordering_columns )
801
+ .apply (add_ints_remote , axis = 1 )
802
+ .to_pandas ()
803
+ )
804
+ pd_result = (
805
+ scalars_pandas_df [columns ].sort_values (ordering_columns ).apply (add_ints , axis = 1 )
806
+ )
807
+
808
+ # bf_result.dtype is 'Int64' while pd_result.dtype is 'object', ignore this
809
+ # mismatch by using check_dtype=False.
810
+ #
811
+ # bf_result.to_numpy() produces an array of numpy.float64's
812
+ # (in system_prerelease tests), while pd_result.to_numpy() produces an
813
+ # array of ints, ignore this mismatch by using check_exact=False.
814
+ pd .testing .assert_series_equal (
815
+ pd_result , bf_result , check_dtype = False , check_exact = False
816
+ )
817
+
818
+
819
+ @pytest .mark .flaky (retries = 2 , delay = 120 )
820
+ def test_df_apply_axis_1_multiindex (session ):
821
+ pd_df = pd .DataFrame (
822
+ {"x" : [1 , 2 , 3 ], "y" : [1.5 , 3.75 , 5 ], "z" : ["pq" , "rs" , "tu" ]},
823
+ index = pd .MultiIndex .from_tuples ([("a" , 100 ), ("a" , 200 ), ("b" , 300 )]),
824
+ )
825
+ bf_df = session .read_pandas (pd_df )
826
+
827
+ def add_numbers (row ):
828
+ return row ["x" ] + row ["y" ]
829
+
830
+ add_numbers_remote = session .remote_function (bigframes .series .Series , float )(
831
+ add_numbers
832
+ )
833
+
834
+ bf_result = bf_df .apply (add_numbers_remote , axis = 1 ).to_pandas ()
835
+ pd_result = pd_df .apply (add_numbers , axis = 1 )
836
+
837
+ # bf_result.dtype is 'Float64' while pd_result.dtype is 'float64', ignore this
838
+ # mismatch by using check_dtype=False.
839
+ #
840
+ # bf_result.index[0].dtype is 'string[pyarrow]' while
841
+ # pd_result.index[0].dtype is 'object', ignore this mismatch by using
842
+ # check_index_type=False.
843
+ pd .testing .assert_series_equal (
844
+ pd_result , bf_result , check_dtype = False , check_index_type = False
845
+ )
846
+
847
+
745
848
def test_df_apply_axis_1_unsupported_callable (scalars_dfs ):
746
849
scalars_df , scalars_pandas_df = scalars_dfs
747
850
columns = ["bool_col" , "int64_col" , "int64_too" , "float64_col" , "string_col" ]
0 commit comments