Skip to content

Commit 1b6087c

Browse files
feat: read_pandas accepts pandas Series and Index objects
1 parent 90bcec5 commit 1b6087c

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

bigframes/pandas/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Literal,
3333
MutableSequence,
3434
Optional,
35+
overload,
3536
Sequence,
3637
Tuple,
3738
Union,
@@ -577,7 +578,22 @@ def read_gbq_table(
577578
read_gbq_table.__doc__ = inspect.getdoc(bigframes.session.Session.read_gbq_table)
578579

579580

581+
@overload
580582
def read_pandas(pandas_dataframe: pandas.DataFrame) -> bigframes.dataframe.DataFrame:
583+
...
584+
585+
586+
@overload
587+
def read_pandas(pandas_dataframe: pandas.Series) -> bigframes.series.Series:
588+
...
589+
590+
591+
@overload
592+
def read_pandas(pandas_dataframe: pandas.Index) -> bigframes.core.indexes.Index:
593+
...
594+
595+
596+
def read_pandas(pandas_dataframe: Union[pandas.DataFrame, pandas.Series, pandas.Index]):
581597
return global_session.with_default_session(
582598
bigframes.session.Session.read_pandas,
583599
pandas_dataframe,

bigframes/session/__init__.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Mapping,
3535
MutableSequence,
3636
Optional,
37+
overload,
3738
Sequence,
3839
Tuple,
3940
Union,
@@ -95,7 +96,9 @@
9596

9697
# Avoid circular imports.
9798
if typing.TYPE_CHECKING:
99+
import bigframes.core.indexes as indices
98100
import bigframes.dataframe as dataframe
101+
import bigframes.series as series
99102

100103
_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"
101104

@@ -963,7 +966,21 @@ def read_gbq_model(self, model_name: str):
963966
model = self.bqclient.get_model(model_ref)
964967
return bigframes.ml.loader.from_bq(self, model)
965968

969+
@overload
970+
def read_pandas(self, pandas_dataframe: pandas.Index) -> indices.Index:
971+
...
972+
973+
@overload
974+
def read_pandas(self, pandas_dataframe: pandas.Series) -> series.Series:
975+
...
976+
977+
@overload
966978
def read_pandas(self, pandas_dataframe: pandas.DataFrame) -> dataframe.DataFrame:
979+
...
980+
981+
def read_pandas(
982+
self, pandas_dataframe: Union[pandas.DataFrame, pandas.Series, pandas.Index]
983+
) -> Union[dataframe.DataFrame, series.Series, indices.Index]:
967984
"""Loads DataFrame from a pandas DataFrame.
968985
969986
The pandas DataFrame will be persisted as a temporary BigQuery table, which can be
@@ -986,13 +1003,31 @@ def read_pandas(self, pandas_dataframe: pandas.DataFrame) -> dataframe.DataFrame
9861003
[2 rows x 2 columns]
9871004
9881005
Args:
989-
pandas_dataframe (pandas.DataFrame):
990-
a pandas DataFrame object to be loaded.
1006+
pandas_dataframe (pandas.DataFrame, pandas.Series, or pandas.Index):
1007+
a pandas DataFrame/Series/Index object to be loaded.
9911008
9921009
Returns:
993-
bigframes.dataframe.DataFrame: The BigQuery DataFrame.
1010+
An equivalent bigframes.pandas.(DataFrame/Series/Index) object
9941011
"""
995-
return self._read_pandas(pandas_dataframe, "read_pandas")
1012+
import bigframes.series as series
1013+
1014+
# Try to handle non-dataframe pandas objects as well
1015+
if isinstance(pandas_dataframe, pandas.Series):
1016+
bf_df = self._read_pandas(pandas.DataFrame(pandas_dataframe), "read_pandas")
1017+
bf_series = typing.cast(series.Series, bf_df[bf_df.columns[0]])
1018+
# wrapping into df can set name to 0 so reset to original object name
1019+
bf_series.name = pandas_dataframe.name
1020+
return bf_series
1021+
if isinstance(pandas_dataframe, pandas.Index):
1022+
return self._read_pandas(
1023+
pandas.DataFrame(index=pandas_dataframe), "read_pandas"
1024+
).index
1025+
if isinstance(pandas_dataframe, pandas.DataFrame):
1026+
return self._read_pandas(pandas_dataframe, "read_pandas")
1027+
else:
1028+
raise ValueError(
1029+
f"read_pandas() expects a pandas dataframe, but got a {type(pandas_dataframe)}"
1030+
)
9961031

9971032
def _read_pandas(
9981033
self, pandas_dataframe: pandas.DataFrame, api_name: str

tests/system/small/test_session.py

+15
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,21 @@ def test_read_pandas(session, scalars_dfs):
421421
pd.testing.assert_frame_equal(result, expected)
422422

423423

424+
def test_read_pandas_series(session):
    """Round-trip a pandas Series through ``session.read_pandas``.

    The Series uses a nullable Int64 dtype for both its values and its
    (non-unique) index; the object returned by ``read_pandas`` is converted
    back with ``to_pandas()`` and must compare equal to the input.
    """
    expected = pd.Series(
        [3, 1, 4, 1, 5],
        dtype=pd.Int64Dtype(),
        index=pd.Index([2, 7, 1, 2, 8], dtype=pd.Int64Dtype()),
    )

    actual = session.read_pandas(expected).to_pandas()

    pd.testing.assert_series_equal(actual, expected)
430+
431+
432+
def test_read_pandas_index(session):
    """Round-trip a pandas Index through ``session.read_pandas``.

    The nullable Int64 index (which contains a duplicate value) is loaded with
    ``read_pandas``, converted back with ``to_pandas()``, and must compare
    equal to the input.
    """
    expected = pd.Index([2, 7, 1, 2, 8], dtype=pd.Int64Dtype())

    actual = session.read_pandas(expected).to_pandas()

    pd.testing.assert_index_equal(actual, expected)
437+
438+
424439
def test_read_pandas_inline_respects_location():
425440
options = bigframes.BigQueryOptions(location="europe-west1")
426441
session = bigframes.Session(options)

0 commit comments

Comments
 (0)