Skip to content

Commit 3cae066

Browse files
authored
feat: default to DATETIME type when loading timezone-naive datetimes from Pandas (#1061)
* Make systest expect DATETIME for naive datetimes * Fix SchemaField repr() when field type not set * Adjust DATETIME detection logic in dataframes * Fix assertions in one of the samples tests
1 parent 070729f commit 3cae066

File tree

7 files changed

+201
-39
lines changed

7 files changed

+201
-39
lines changed

google/cloud/bigquery/_pandas_helpers.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""Shared helper functions for connecting BigQuery and pandas."""
1616

1717
import concurrent.futures
18+
from datetime import datetime
1819
import functools
20+
from itertools import islice
1921
import logging
2022
import queue
2123
import warnings
@@ -85,9 +87,7 @@ def _to_wkb(v):
8587
_PANDAS_DTYPE_TO_BQ = {
8688
"bool": "BOOLEAN",
8789
"datetime64[ns, UTC]": "TIMESTAMP",
88-
# TODO: Update to DATETIME in V3
89-
# https://github.com/googleapis/python-bigquery/issues/985
90-
"datetime64[ns]": "TIMESTAMP",
90+
"datetime64[ns]": "DATETIME",
9191
"float32": "FLOAT",
9292
"float64": "FLOAT",
9393
"int8": "INTEGER",
@@ -379,6 +379,36 @@ def _first_valid(series):
379379
return series.at[first_valid_index]
380380

381381

382+
def _first_array_valid(series):
383+
"""Return the first "meaningful" element from the array series.
384+
385+
Here, "meaningful" means the first non-None element in one of the arrays that can
386+
be used for type detection.
387+
"""
388+
first_valid_index = series.first_valid_index()
389+
if first_valid_index is None:
390+
return None
391+
392+
valid_array = series.at[first_valid_index]
393+
valid_item = next((item for item in valid_array if not pandas.isna(item)), None)
394+
395+
if valid_item is not None:
396+
return valid_item
397+
398+
# Valid item is None because all items in the "valid" array are invalid. Try
399+
# to find a true valid array manually.
400+
for array in islice(series, first_valid_index + 1, None):
401+
try:
402+
array_iter = iter(array)
403+
except TypeError:
404+
continue # Not an array, apparently, e.g. None, thus skip.
405+
valid_item = next((item for item in array_iter if not pandas.isna(item)), None)
406+
if valid_item is not None:
407+
break
408+
409+
return valid_item
410+
411+
382412
def dataframe_to_bq_schema(dataframe, bq_schema):
383413
"""Convert a pandas DataFrame schema to a BigQuery schema.
384414
@@ -482,6 +512,19 @@ def augment_schema(dataframe, current_bq_schema):
482512
# `pyarrow.ListType`
483513
detected_mode = "REPEATED"
484514
detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.values.type.id)
515+
516+
# For timezone-naive datetimes, pyarrow assumes the UTC timezone and adds
517+
# it to such datetimes, causing them to be recognized as TIMESTAMP type.
518+
# We thus additionally check the actual data to see if we need to overrule
519+
# that and choose DATETIME instead.
520+
# Note that this should only be needed for datetime values inside a list,
521+
# since scalar datetime values have a proper Pandas dtype that allows
522+
# distinguishing between timezone-naive and timezone-aware values before
523+
# even requiring the additional schema augment logic in this method.
524+
if detected_type == "TIMESTAMP":
525+
valid_item = _first_array_valid(dataframe[field.name])
526+
if isinstance(valid_item, datetime) and valid_item.tzinfo is None:
527+
detected_type = "DATETIME"
485528
else:
486529
detected_mode = field.mode
487530
detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id)

google/cloud/bigquery/schema.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,20 @@ def _key(self):
257257
Returns:
258258
Tuple: The contents of this :class:`~google.cloud.bigquery.schema.SchemaField`.
259259
"""
260-
field_type = self.field_type.upper()
261-
if field_type == "STRING" or field_type == "BYTES":
262-
if self.max_length is not None:
263-
field_type = f"{field_type}({self.max_length})"
264-
elif field_type.endswith("NUMERIC"):
265-
if self.precision is not None:
266-
if self.scale is not None:
267-
field_type = f"{field_type}({self.precision}, {self.scale})"
268-
else:
269-
field_type = f"{field_type}({self.precision})"
260+
field_type = self.field_type.upper() if self.field_type is not None else None
261+
262+
# Type can temporarily be set to None if the code needs a SchemaField instance,
263+
but has not determined the exact type of the field yet.
264+
if field_type is not None:
265+
if field_type == "STRING" or field_type == "BYTES":
266+
if self.max_length is not None:
267+
field_type = f"{field_type}({self.max_length})"
268+
elif field_type.endswith("NUMERIC"):
269+
if self.precision is not None:
270+
if self.scale is not None:
271+
field_type = f"{field_type}({self.precision}, {self.scale})"
272+
else:
273+
field_type = f"{field_type}({self.precision})"
270274

271275
policy_tags = (
272276
None if self.policy_tags is None else tuple(sorted(self.policy_tags.names))

samples/tests/test_load_table_dataframe.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_load_table_dataframe(capsys, client, random_table_id):
4444
"INTEGER",
4545
"FLOAT",
4646
"TIMESTAMP",
47-
"TIMESTAMP",
47+
"DATETIME",
4848
]
4949

5050
df = client.list_rows(table).to_dataframe()
@@ -64,9 +64,9 @@ def test_load_table_dataframe(capsys, client, random_table_id):
6464
pandas.Timestamp("1983-05-09T11:00:00+00:00"),
6565
]
6666
assert df["dvd_release"].tolist() == [
67-
pandas.Timestamp("2003-10-22T10:00:00+00:00"),
68-
pandas.Timestamp("2002-07-16T09:00:00+00:00"),
69-
pandas.Timestamp("2008-01-14T08:00:00+00:00"),
70-
pandas.Timestamp("2002-01-22T07:00:00+00:00"),
67+
pandas.Timestamp("2003-10-22T10:00:00"),
68+
pandas.Timestamp("2002-07-16T09:00:00"),
69+
pandas.Timestamp("2008-01-14T08:00:00"),
70+
pandas.Timestamp("2002-01-22T07:00:00"),
7171
]
7272
assert df["wikidata_id"].tolist() == [u"Q16403", u"Q25043", u"Q24953", u"Q24980"]

tests/system/test_pandas.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
6565
).dt.tz_localize(datetime.timezone.utc),
6666
),
6767
(
68-
"dt_col",
68+
"dt_col_no_tz",
6969
pandas.Series(
7070
[
7171
datetime.datetime(2010, 1, 2, 3, 44, 50),
@@ -130,7 +130,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
130130
),
131131
),
132132
(
133-
"array_dt_col",
133+
"array_dt_col_no_tz",
134134
pandas.Series(
135135
[
136136
[datetime.datetime(2010, 1, 2, 3, 44, 50)],
@@ -196,9 +196,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
196196
assert tuple(table.schema) == (
197197
bigquery.SchemaField("bool_col", "BOOLEAN"),
198198
bigquery.SchemaField("ts_col", "TIMESTAMP"),
199-
# TODO: Update to DATETIME in V3
200-
# https://github.com/googleapis/python-bigquery/issues/985
201-
bigquery.SchemaField("dt_col", "TIMESTAMP"),
199+
bigquery.SchemaField("dt_col_no_tz", "DATETIME"),
202200
bigquery.SchemaField("float32_col", "FLOAT"),
203201
bigquery.SchemaField("float64_col", "FLOAT"),
204202
bigquery.SchemaField("int8_col", "INTEGER"),
@@ -212,9 +210,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
212210
bigquery.SchemaField("time_col", "TIME"),
213211
bigquery.SchemaField("array_bool_col", "BOOLEAN", mode="REPEATED"),
214212
bigquery.SchemaField("array_ts_col", "TIMESTAMP", mode="REPEATED"),
215-
# TODO: Update to DATETIME in V3
216-
# https://github.com/googleapis/python-bigquery/issues/985
217-
bigquery.SchemaField("array_dt_col", "TIMESTAMP", mode="REPEATED"),
213+
bigquery.SchemaField("array_dt_col_no_tz", "DATETIME", mode="REPEATED"),
218214
bigquery.SchemaField("array_float32_col", "FLOAT", mode="REPEATED"),
219215
bigquery.SchemaField("array_float64_col", "FLOAT", mode="REPEATED"),
220216
bigquery.SchemaField("array_int8_col", "INTEGER", mode="REPEATED"),
@@ -225,6 +221,7 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
225221
bigquery.SchemaField("array_uint16_col", "INTEGER", mode="REPEATED"),
226222
bigquery.SchemaField("array_uint32_col", "INTEGER", mode="REPEATED"),
227223
)
224+
228225
assert numpy.array(
229226
sorted(map(list, bigquery_client.list_rows(table)), key=lambda r: r[5]),
230227
dtype="object",
@@ -237,13 +234,11 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
237234
datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc),
238235
datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc),
239236
],
240-
# dt_col
241-
# TODO: Remove tzinfo in V3.
242-
# https://github.com/googleapis/python-bigquery/issues/985
237+
# dt_col_no_tz
243238
[
244-
datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc),
245-
datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc),
246-
datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc),
239+
datetime.datetime(2010, 1, 2, 3, 44, 50),
240+
datetime.datetime(2011, 2, 3, 14, 50, 59),
241+
datetime.datetime(2012, 3, 14, 15, 16),
247242
],
248243
# float32_col
249244
[1.0, 2.0, 3.0],
@@ -280,12 +275,10 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i
280275
[datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc)],
281276
],
282277
# array_dt_col
283-
# TODO: Remove tzinfo in V3.
284-
# https://github.com/googleapis/python-bigquery/issues/985
285278
[
286-
[datetime.datetime(2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc)],
287-
[datetime.datetime(2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc)],
288-
[datetime.datetime(2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc)],
279+
[datetime.datetime(2010, 1, 2, 3, 44, 50)],
280+
[datetime.datetime(2011, 2, 3, 14, 50, 59)],
281+
[datetime.datetime(2012, 3, 14, 15, 16)],
289282
],
290283
# array_float32_col
291284
[[1.0], [2.0], [3.0]],

tests/unit/test__pandas_helpers.py

+117
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,46 @@ def test_dataframe_to_bq_schema_geography(module_under_test):
12081208
)
12091209

12101210

1211+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1212+
def test__first_array_valid_no_valid_items(module_under_test):
1213+
series = pandas.Series([None, pandas.NA, float("NaN")])
1214+
result = module_under_test._first_array_valid(series)
1215+
assert result is None
1216+
1217+
1218+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1219+
def test__first_array_valid_valid_item_exists(module_under_test):
1220+
series = pandas.Series([None, [0], [1], None])
1221+
result = module_under_test._first_array_valid(series)
1222+
assert result == 0
1223+
1224+
1225+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1226+
def test__first_array_valid_all_nan_items_in_first_valid_candidate(module_under_test):
1227+
import numpy
1228+
1229+
series = pandas.Series(
1230+
[
1231+
None,
1232+
[None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan],
1233+
None,
1234+
[None, None],
1235+
[None, float("NaN"), pandas.NA, pandas.NaT, numpy.nan, 42, None],
1236+
[1, 2, 3],
1237+
None,
1238+
]
1239+
)
1240+
result = module_under_test._first_array_valid(series)
1241+
assert result == 42
1242+
1243+
1244+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1245+
def test__first_array_valid_no_arrays_with_valid_items(module_under_test):
1246+
series = pandas.Series([[None, None], [None, None]])
1247+
result = module_under_test._first_array_valid(series)
1248+
assert result is None
1249+
1250+
12111251
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
12121252
def test_augment_schema_type_detection_succeeds(module_under_test):
12131253
dataframe = pandas.DataFrame(
@@ -1274,6 +1314,59 @@ def test_augment_schema_type_detection_succeeds(module_under_test):
12741314
assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)
12751315

12761316

1317+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1318+
def test_augment_schema_repeated_fields(module_under_test):
1319+
dataframe = pandas.DataFrame(
1320+
data=[
1321+
# Include some values useless for type detection to make sure the logic
1322+
# indeed finds the value that is suitable.
1323+
{"string_array": None, "timestamp_array": None, "datetime_array": None},
1324+
{
1325+
"string_array": [None],
1326+
"timestamp_array": [None],
1327+
"datetime_array": [None],
1328+
},
1329+
{"string_array": None, "timestamp_array": None, "datetime_array": None},
1330+
{
1331+
"string_array": [None, "foo"],
1332+
"timestamp_array": [
1333+
None,
1334+
datetime.datetime(
1335+
2005, 5, 31, 14, 25, 55, tzinfo=datetime.timezone.utc
1336+
),
1337+
],
1338+
"datetime_array": [None, datetime.datetime(2005, 5, 31, 14, 25, 55)],
1339+
},
1340+
{"string_array": None, "timestamp_array": None, "datetime_array": None},
1341+
]
1342+
)
1343+
1344+
current_schema = (
1345+
schema.SchemaField("string_array", field_type=None, mode="NULLABLE"),
1346+
schema.SchemaField("timestamp_array", field_type=None, mode="NULLABLE"),
1347+
schema.SchemaField("datetime_array", field_type=None, mode="NULLABLE"),
1348+
)
1349+
1350+
with warnings.catch_warnings(record=True) as warned:
1351+
augmented_schema = module_under_test.augment_schema(dataframe, current_schema)
1352+
1353+
# there should be no relevant warnings
1354+
unwanted_warnings = [
1355+
warning for warning in warned if "Pyarrow could not" in str(warning)
1356+
]
1357+
assert not unwanted_warnings
1358+
1359+
# the augmented schema must match the expected
1360+
expected_schema = (
1361+
schema.SchemaField("string_array", field_type="STRING", mode="REPEATED"),
1362+
schema.SchemaField("timestamp_array", field_type="TIMESTAMP", mode="REPEATED"),
1363+
schema.SchemaField("datetime_array", field_type="DATETIME", mode="REPEATED"),
1364+
)
1365+
1366+
by_name = operator.attrgetter("name")
1367+
assert sorted(augmented_schema, key=by_name) == sorted(expected_schema, key=by_name)
1368+
1369+
12771370
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
12781371
def test_augment_schema_type_detection_fails(module_under_test):
12791372
dataframe = pandas.DataFrame(
@@ -1310,6 +1403,30 @@ def test_augment_schema_type_detection_fails(module_under_test):
13101403
assert "struct_field" in warning_msg and "struct_field_2" in warning_msg
13111404

13121405

1406+
@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
1407+
def test_augment_schema_type_detection_fails_array_data(module_under_test):
1408+
dataframe = pandas.DataFrame(
1409+
data=[{"all_none_array": [None, float("NaN")], "empty_array": []}]
1410+
)
1411+
current_schema = [
1412+
schema.SchemaField("all_none_array", field_type=None, mode="NULLABLE"),
1413+
schema.SchemaField("empty_array", field_type=None, mode="NULLABLE"),
1414+
]
1415+
1416+
with warnings.catch_warnings(record=True) as warned:
1417+
augmented_schema = module_under_test.augment_schema(dataframe, current_schema)
1418+
1419+
assert augmented_schema is None
1420+
1421+
expected_warnings = [
1422+
warning for warning in warned if "could not determine" in str(warning)
1423+
]
1424+
assert len(expected_warnings) == 1
1425+
warning_msg = str(expected_warnings[0])
1426+
assert "pyarrow" in warning_msg.lower()
1427+
assert "all_none_array" in warning_msg and "empty_array" in warning_msg
1428+
1429+
13131430
def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
13141431
pandas = pytest.importorskip("pandas")
13151432

tests/unit/test_client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7153,7 +7153,7 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
71537153
SchemaField("int_col", "INTEGER"),
71547154
SchemaField("float_col", "FLOAT"),
71557155
SchemaField("bool_col", "BOOLEAN"),
7156-
SchemaField("dt_col", "TIMESTAMP"),
7156+
SchemaField("dt_col", "DATETIME"),
71577157
SchemaField("ts_col", "TIMESTAMP"),
71587158
SchemaField("date_col", "DATE"),
71597159
SchemaField("time_col", "TIME"),
@@ -7660,7 +7660,7 @@ def test_load_table_from_dataframe_w_partial_schema(self):
76607660
SchemaField("int_as_float_col", "INTEGER"),
76617661
SchemaField("float_col", "FLOAT"),
76627662
SchemaField("bool_col", "BOOLEAN"),
7663-
SchemaField("dt_col", "TIMESTAMP"),
7663+
SchemaField("dt_col", "DATETIME"),
76647664
SchemaField("ts_col", "TIMESTAMP"),
76657665
SchemaField("string_col", "STRING"),
76667666
SchemaField("bytes_col", "BYTES"),

tests/unit/test_schema.py

+5
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ def test___repr__(self):
529529
expected = "SchemaField('field1', 'STRING', 'NULLABLE', None, (), None)"
530530
self.assertEqual(repr(field1), expected)
531531

532+
def test___repr__type_not_set(self):
533+
field1 = self._make_one("field1", field_type=None)
534+
expected = "SchemaField('field1', None, 'NULLABLE', None, (), None)"
535+
self.assertEqual(repr(field1), expected)
536+
532537
def test___repr__evaluable_no_policy_tags(self):
533538
field = self._make_one("field1", "STRING", "REQUIRED", "Description")
534539
field_repr = repr(field)

0 commit comments

Comments
 (0)