Skip to content

Commit a07397a

Browse files
committed
Improve and refactor pyarrow schema detection
Add more pyarrow types, convert to pyarrow only the columns the schema could not be detected for, etc.
1 parent dd43f6b commit a07397a

File tree

2 files changed

+218
-51
lines changed

2 files changed

+218
-51
lines changed

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 88 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,35 @@ def pyarrow_timestamp():
110110
"TIME": pyarrow_time,
111111
"TIMESTAMP": pyarrow_timestamp,
112112
}
113-
ARROW_SCALARS_TO_BQ = {
114-
arrow_type(): bq_type # TODO: explain why calling arrow_type()
115-
for bq_type, arrow_type in BQ_TO_ARROW_SCALARS.items()
113+
ARROW_SCALAR_IDS_TO_BQ = {
114+
# https://arrow.apache.org/docs/python/api/datatypes.html#type-classes
115+
pyarrow.bool_().id: "BOOL",
116+
pyarrow.int8().id: "INT64",
117+
pyarrow.int16().id: "INT64",
118+
pyarrow.int32().id: "INT64",
119+
pyarrow.int64().id: "INT64",
120+
pyarrow.uint8().id: "INT64",
121+
pyarrow.uint16().id: "INT64",
122+
pyarrow.uint32().id: "INT64",
123+
pyarrow.uint64().id: "INT64",
124+
pyarrow.float16().id: "FLOAT64",
125+
pyarrow.float32().id: "FLOAT64",
126+
pyarrow.float64().id: "FLOAT64",
127+
pyarrow.time32("ms").id: "TIME",
128+
pyarrow.time64("ns").id: "TIME",
129+
pyarrow.timestamp("ns").id: "TIMESTAMP",
130+
pyarrow.date32().id: "DATE",
131+
pyarrow.date64().id: "DATETIME", # because millisecond resolution
132+
pyarrow.binary().id: "BYTES",
133+
pyarrow.string().id: "STRING", # also alias for pyarrow.utf8()
134+
pyarrow.decimal128(38, scale=9).id: "NUMERIC",
135+
# The exact decimal's scale and precision are not important, as only
136+
# the type ID matters, and it's the same for all decimal128 instances.
116137
}
138+
117139
else: # pragma: NO COVER
118140
BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER
119-
ARROW_SCALARS_TO_BQ = {} # pragma: NO_COVER
141+
ARROW_SCALAR_IDS_TO_BQ = {} # pragma: NO COVER
120142

121143

122144
def bq_to_arrow_struct_data_type(field):
@@ -269,6 +291,8 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
269291
bq_schema_unused = set()
270292

271293
bq_schema_out = []
294+
unknown_type_fields = []
295+
272296
for column, dtype in list_columns_and_indexes(dataframe):
273297
# Use provided type from schema, if present.
274298
bq_field = bq_schema_index.get(column)
@@ -280,12 +304,12 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
280304
# Otherwise, try to automatically determine the type based on the
281305
# pandas dtype.
282306
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
283-
if not bq_type:
284-
warnings.warn(u"Unable to determine type of column '{}'.".format(column))
285-
286307
bq_field = schema.SchemaField(column, bq_type)
287308
bq_schema_out.append(bq_field)
288309

310+
if bq_field.field_type is None:
311+
unknown_type_fields.append(bq_field)
312+
289313
# Catch any schema mismatch. The developer explicitly asked to serialize a
290314
# column, but it was not found.
291315
if bq_schema_unused:
@@ -297,42 +321,70 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
297321

298322
# If schema detection was not successful for all columns, also try with
299323
# pyarrow, if available.
300-
if any(field.field_type is None for field in bq_schema_out):
324+
if unknown_type_fields:
301325
if not pyarrow:
326+
msg = u"Could not determine the type of columns: {}".format(
327+
", ".join(field.name for field in unknown_type_fields)
328+
)
329+
warnings.warn(msg)
302330
return None # We cannot detect the schema in full.
303331

304-
arrow_table = dataframe_to_arrow(dataframe, bq_schema_out)
305-
arrow_schema_index = {field.name: field.type for field in arrow_table}
332+
# The currate_schema() helper itself will also issue unknown type
333+
# warnings if detection still fails for any of the fields.
334+
bq_schema_out = currate_schema(dataframe, bq_schema_out)
306335

307-
currated_schema = []
308-
for schema_field in bq_schema_out:
309-
if schema_field.field_type is not None:
310-
currated_schema.append(schema_field)
311-
continue
336+
return tuple(bq_schema_out) if bq_schema_out else None
312337

313-
detected_type = ARROW_SCALARS_TO_BQ.get(
314-
arrow_schema_index.get(schema_field.name)
315-
)
316-
if detected_type is None:
317-
warnings.warn(
318-
u"Pyarrow could not determine the type of column '{}'.".format(
319-
schema_field.name
320-
)
321-
)
322-
return None
323-
324-
new_field = schema.SchemaField(
325-
name=schema_field.name,
326-
field_type=detected_type,
327-
mode=schema_field.mode,
328-
description=schema_field.description,
329-
fields=schema_field.fields,
330-
)
331-
currated_schema.append(new_field)
332338

333-
bq_schema_out = currated_schema
339+
def currate_schema(dataframe, current_bq_schema):
    """Deduce missing field types via pyarrow and return an improved schema.

    This function requires ``pyarrow`` to run. Fields whose type is already
    known are passed through untouched; for each field whose type is ``None``
    the matching dataframe column is converted to a pyarrow array and the
    arrow type ID is mapped back to a BigQuery type. If any field still
    cannot be resolved, a warning is issued and ``None`` is returned. If all
    types are already known, a shallow copy of the given schema is returned.

    Args:
        dataframe (pandas.DataFrame):
            DataFrame for which some of the field types are still unknown.
        current_bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
            A BigQuery schema for ``dataframe``. The types of some or all of
            the fields may be ``None``.
    Returns:
        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]
    """
    resolved_fields = []
    undetected_fields = []

    for field in current_bq_schema:
        if field.field_type is not None:
            # Type already known, keep the field as-is.
            resolved_fields.append(field)
            continue

        # Let pyarrow infer an arrow type from the column's values, then map
        # the arrow type ID back to a BigQuery type name (if known).
        arrow_array = pyarrow.array(dataframe[field.name])
        detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_array.type.id)

        if detected_type is None:
            undetected_fields.append(field)
        else:
            resolved_fields.append(
                schema.SchemaField(
                    name=field.name,
                    field_type=detected_type,
                    mode=field.mode,
                    description=field.description,
                    fields=field.fields,
                )
            )

    if undetected_fields:
        warnings.warn(
            u"Pyarrow could not determine the type of columns: {}.".format(
                ", ".join(field.name for field in undetected_fields)
            )
        )
        return None

    return resolved_fields
336388

337389

338390
def dataframe_to_arrow(dataframe, bq_schema):

bigquery/tests/unit/test__pandas_helpers.py

Lines changed: 130 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import datetime
1717
import decimal
1818
import functools
19+
import operator
1920
import warnings
2021

2122
import mock
@@ -911,47 +912,66 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
911912
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
912913
dataframe = pandas.DataFrame(
913914
data=[
914-
{"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)},
915-
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
915+
{"id": 10, "status": u"FOO", "execution_date": datetime.date(2019, 5, 10)},
916+
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
916917
]
917918
)
918919

919920
no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)
920921

921-
with no_pyarrow_patch:
922+
with no_pyarrow_patch, warnings.catch_warnings(record=True) as warned:
922923
detected_schema = module_under_test.dataframe_to_bq_schema(
923924
dataframe, bq_schema=[]
924925
)
925926

926927
assert detected_schema is None
927928

929+
# a warning should also be issued
930+
expected_warnings = [
931+
warning for warning in warned if "could not determine" in str(warning).lower()
932+
]
933+
assert len(expected_warnings) == 1
934+
msg = str(expected_warnings[0])
935+
assert "execution_date" in msg and "created_at" in msg
936+
928937

929938
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
930939
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
931940
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
932941
dataframe = pandas.DataFrame(
933942
data=[
934-
{"id": 10, "status": "FOO", "created_at": datetime.date(2019, 5, 10)},
935-
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
943+
{"id": 10, "status": u"FOO", "created_at": datetime.date(2019, 5, 10)},
944+
{"id": 20, "status": u"BAR", "created_at": datetime.date(2018, 9, 12)},
936945
]
937946
)
938947

939-
detected_schema = module_under_test.dataframe_to_bq_schema(dataframe, bq_schema=[])
948+
with warnings.catch_warnings(record=True) as warned:
949+
detected_schema = module_under_test.dataframe_to_bq_schema(
950+
dataframe, bq_schema=[]
951+
)
952+
940953
expected_schema = (
941954
schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
942955
schema.SchemaField("status", "STRING", mode="NULLABLE"),
943956
schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
944957
)
945-
assert detected_schema == expected_schema
958+
by_name = operator.attrgetter("name")
959+
assert sorted(detected_schema, key=by_name) == sorted(expected_schema, key=by_name)
960+
961+
# there should be no relevant warnings
962+
unwanted_warnings = [
963+
warning for warning in warned if "could not determine" in str(warning).lower()
964+
]
965+
assert not unwanted_warnings
946966

947967

948968
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
949969
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
950970
def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
951971
dataframe = pandas.DataFrame(
952972
data=[
953-
{"id": 10, "status": "FOO", "all_items": [10.1, 10.2]},
954-
{"id": 20, "status": "BAR", "all_items": [20.1, 20.2]},
973+
{"struct_field": {"one": 2}, "status": u"FOO"},
974+
{"struct_field": {"two": u"222"}, "status": u"BAR"},
955975
]
956976
)
957977

@@ -962,12 +982,107 @@ def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
962982

963983
assert detected_schema is None
964984

965-
expected_warnings = []
966-
for warning in warned:
967-
if "Pyarrow could not" in str(warning):
968-
expected_warnings.append(warning)
985+
# a warning should also be issued
986+
expected_warnings = [
987+
warning for warning in warned if "could not determine" in str(warning).lower()
988+
]
989+
assert len(expected_warnings) == 1
990+
assert "struct_field" in str(expected_warnings[0])
991+
992+
993+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
994+
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
995+
def test_currate_schema_type_detection_succeeds(module_under_test):
996+
dataframe = pandas.DataFrame(
997+
data=[
998+
{
999+
"bool_field": False,
1000+
"int_field": 123,
1001+
"float_field": 3.141592,
1002+
"time_field": datetime.time(17, 59, 47),
1003+
"timestamp_field": datetime.datetime(2005, 5, 31, 14, 25, 55),
1004+
"date_field": datetime.date(2005, 5, 31),
1005+
"bytes_field": b"some bytes",
1006+
"string_field": u"some characters",
1007+
"numeric_field": decimal.Decimal("123.456"),
1008+
}
1009+
]
1010+
)
1011+
1012+
# NOTE: In Pandas dataframe, the dtype of Python's datetime instances is
1013+
# set to "datetime64[ns]", and pyarrow converts that to pyarrow.TimestampArray.
1014+
# We thus cannot expect to get a DATETIME date when converting back to the
1015+
# BigQuery type.
1016+
1017+
current_schema = (
1018+
schema.SchemaField("bool_field", field_type=None, mode="NULLABLE"),
1019+
schema.SchemaField("int_field", field_type=None, mode="NULLABLE"),
1020+
schema.SchemaField("float_field", field_type=None, mode="NULLABLE"),
1021+
schema.SchemaField("time_field", field_type=None, mode="NULLABLE"),
1022+
schema.SchemaField("timestamp_field", field_type=None, mode="NULLABLE"),
1023+
schema.SchemaField("date_field", field_type=None, mode="NULLABLE"),
1024+
schema.SchemaField("bytes_field", field_type=None, mode="NULLABLE"),
1025+
schema.SchemaField("string_field", field_type=None, mode="NULLABLE"),
1026+
schema.SchemaField("numeric_field", field_type=None, mode="NULLABLE"),
1027+
)
1028+
1029+
with warnings.catch_warnings(record=True) as warned:
1030+
currated_schema = module_under_test.currate_schema(dataframe, current_schema)
9691031

1032+
# there should be no relevant warnings
1033+
unwanted_warnings = [
1034+
warning for warning in warned if "Pyarrow could not" in str(warning)
1035+
]
1036+
assert not unwanted_warnings
1037+
1038+
# the currated schema must match the expected
1039+
expected_schema = (
1040+
schema.SchemaField("bool_field", field_type="BOOL", mode="NULLABLE"),
1041+
schema.SchemaField("int_field", field_type="INT64", mode="NULLABLE"),
1042+
schema.SchemaField("float_field", field_type="FLOAT64", mode="NULLABLE"),
1043+
schema.SchemaField("time_field", field_type="TIME", mode="NULLABLE"),
1044+
schema.SchemaField("timestamp_field", field_type="TIMESTAMP", mode="NULLABLE"),
1045+
schema.SchemaField("date_field", field_type="DATE", mode="NULLABLE"),
1046+
schema.SchemaField("bytes_field", field_type="BYTES", mode="NULLABLE"),
1047+
schema.SchemaField("string_field", field_type="STRING", mode="NULLABLE"),
1048+
schema.SchemaField("numeric_field", field_type="NUMERIC", mode="NULLABLE"),
1049+
)
1050+
by_name = operator.attrgetter("name")
1051+
assert sorted(currated_schema, key=by_name) == sorted(expected_schema, key=by_name)
1052+
1053+
1054+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1055+
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
1056+
def test_currate_schema_type_detection_fails(module_under_test):
1057+
dataframe = pandas.DataFrame(
1058+
data=[
1059+
{
1060+
"status": u"FOO",
1061+
"struct_field": {"one": 1},
1062+
"struct_field_2": {"foo": u"123"},
1063+
},
1064+
{
1065+
"status": u"BAR",
1066+
"struct_field": {"two": u"111"},
1067+
"struct_field_2": {"bar": 27},
1068+
},
1069+
]
1070+
)
1071+
current_schema = [
1072+
schema.SchemaField("status", field_type="STRING", mode="NULLABLE"),
1073+
schema.SchemaField("struct_field", field_type=None, mode="NULLABLE"),
1074+
schema.SchemaField("struct_field_2", field_type=None, mode="NULLABLE"),
1075+
]
1076+
1077+
with warnings.catch_warnings(record=True) as warned:
1078+
currated_schema = module_under_test.currate_schema(dataframe, current_schema)
1079+
1080+
assert currated_schema is None
1081+
1082+
expected_warnings = [
1083+
warning for warning in warned if "could not determine" in str(warning)
1084+
]
9701085
assert len(expected_warnings) == 1
9711086
warning_msg = str(expected_warnings[0])
972-
assert "all_items" in warning_msg
973-
assert "could not determine the type" in warning_msg
1087+
assert "pyarrow" in warning_msg.lower()
1088+
assert "struct_field" in warning_msg and "struct_field_2" in warning_msg

0 commit comments

Comments
 (0)