16
16
import datetime
17
17
import decimal
18
18
import functools
19
+ import operator
19
20
import warnings
20
21
21
22
import mock
@@ -911,47 +912,66 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
911
912
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
    """Check schema detection when the module's ``pyarrow`` attribute is ``None``.

    The dataframe holds ``datetime.date`` values in two columns
    (``execution_date`` and ``created_at``).  With ``pyarrow`` patched out,
    ``dataframe_to_bq_schema`` is expected to return ``None`` and to emit
    exactly one "could not determine" warning naming both columns.
    """
    dataframe = pandas.DataFrame(
        data=[
            {"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
            {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
        ]
    )

    no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

    with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
        # Record every warning regardless of the ambient filters or of a
        # previously populated "already warned" registry; otherwise the
        # count asserted below depends on test order and runner options.
        warnings.simplefilter("always")
        detected_schema = module_under_test.dataframe_to_bq_schema(
            dataframe, bq_schema=[]
        )

    assert detected_schema is None

    # a warning should also be issued
    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert len(expected_warnings) == 1
    msg = str(expected_warnings[0])
    assert "execution_date" in msg and "created_at" in msg
936
+
928
937
929
938
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
    """Check schema detection for a ``datetime.date`` column when pyarrow is present.

    ``dataframe_to_bq_schema`` is expected to return INTEGER / STRING / DATE
    fields for ``id`` / ``status`` / ``created_at`` (compared irrespective of
    field order) without emitting a "could not determine" warning.
    """
    dataframe = pandas.DataFrame(
        data=[
            {"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)},
            {"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
        ]
    )

    with warnings.catch_warnings(record=True) as warned:
        # Without this, an inherited "ignore"/"once" filter could hide the
        # warning and let the negative assertion below pass vacuously.
        warnings.simplefilter("always")
        detected_schema = module_under_test.dataframe_to_bq_schema(
            dataframe, bq_schema=[]
        )

    expected_schema = (
        schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
        schema.SchemaField("status", "STRING", mode="NULLABLE"),
        schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
    )
    # Field order is not part of the contract being tested; compare by name.
    by_name = operator.attrgetter("name")
    assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name)

    # there should be no relevant warnings
    unwanted_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert not unwanted_warnings
946
966
947
967
948
968
@pytest .mark .skipif (pandas is None , reason = "Requires `pandas`" )
949
969
@pytest .mark .skipif (pyarrow is None , reason = "Requires `pyarrow`" )
950
970
def test_dataframe_to_bq_schema_pyarrow_fallback_fails (module_under_test ):
951
971
dataframe = pandas .DataFrame (
952
972
data = [
953
- {"id " : 10 , "status " : "FOO" , "all_items " : [ 10.1 , 10.2 ] },
954
- {"id " : 20 , "status " : "BAR" , "all_items " : [ 20.1 , 20.2 ] },
973
+ {"struct_field " : { "one " : 2 } , "status " : u"FOO" },
974
+ {"struct_field " : { "two " : u"222" } , "status " : u"BAR" },
955
975
]
956
976
)
957
977
@@ -962,12 +982,107 @@ def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
962
982
963
983
assert detected_schema is None
964
984
965
- expected_warnings = []
966
- for warning in warned :
967
- if "Pyarrow could not" in str (warning ):
968
- expected_warnings .append (warning )
985
+ # a warning should also be issued
986
+ expected_warnings = [
987
+ warning for warning in warned if "could not determine" in str (warning ).lower ()
988
+ ]
989
+ assert len (expected_warnings ) == 1
990
+ assert "struct_field" in str (expected_warnings [0 ])
991
+
992
+
993
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_currate_schema_type_detection_succeeds(module_under_test):
    """Check that ``currate_schema`` fills in every missing field type.

    Each schema field is passed with ``field_type=None``; the returned schema
    is expected to carry the BigQuery type matching the Python value stored
    in the corresponding dataframe column, and no type-detection warning
    should be emitted.
    """
    dataframe = pandas.DataFrame(
        data=[
            {
                "bool_field": False,
                "int_field": 123,
                "float_field": 3.141592,
                "time_field": datetime.time(17, 59, 47),
                "timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55),
                "date_field": datetime.date(2005, 5, 31),
                "bytes_field": b"some bytes",
                "string_field": u"some characters",
                "numeric_field": decimal.Decimal("123.456"),
            }
        ]
    )

    # NOTE: In Pandas dataframe, the dtype of Python's datetime instances is
    # set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray.
    # We thus cannot expect to get a DATETIME type when converting back to the
    # BigQuery type.

    current_schema = (
        schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
    )

    with warnings.catch_warnings(record=True) as warned:
        # Record everything so the negative assertion below cannot pass
        # merely because an inherited filter swallowed the warning.
        warnings.simplefilter("always")
        currated_schema = module_under_test.currate_schema(dataframe, current_schema)

    # there should be no relevant warnings (matched case-insensitively, the
    # same way the sibling tests in this module match the message)
    unwanted_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert not unwanted_warnings

    # the currated schema must match the expected
    expected_schema = (
        schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
        schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
        schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
        schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
        schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
        schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
        schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
        schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
        schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
    )
    # Field order is not part of the contract being tested; compare by name.
    by_name = operator.attrgetter("name")
    assert sorted(currated_schema, key=by_name) == sorted(expected_schema, key=by_name)
1052
+
1053
+
1054
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_currate_schema_type_detection_fails(module_under_test):
    """Check ``currate_schema`` when some column types cannot be detected.

    ``struct_field`` and ``struct_field_2`` hold dicts whose keys and value
    types differ between rows, and their schema fields have
    ``field_type=None``.  ``currate_schema`` is expected to return ``None``
    and emit exactly one warning that mentions pyarrow and both columns.
    """
    dataframe = pandas.DataFrame(
        data=[
            {
                "status": u"FOO",
                "struct_field": {"one": 1},
                "struct_field_2": {"foo": u"123"},
            },
            {
                "status": u"BAR",
                "struct_field": {"two": u"111"},
                "struct_field_2": {"bar": 27},
            },
        ]
    )
    current_schema = [
        schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
        schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
        schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
    ]

    with warnings.catch_warnings(record=True) as warned:
        # Record every warning regardless of inherited filters or of an
        # earlier test having already triggered the same warning.
        warnings.simplefilter("always")
        currated_schema = module_under_test.currate_schema(dataframe, current_schema)

    assert currated_schema is None

    # Matched case-insensitively, like the other tests in this module.
    expected_warnings = [
        warning for warning in warned if "could not determine" in str(warning).lower()
    ]
    assert len(expected_warnings) == 1
    warning_msg = str(expected_warnings[0])
    assert "pyarrow" in warning_msg.lower()
    assert "struct_field" in warning_msg and "struct_field_2" in warning_msg
0 commit comments