Skip to content

Commit bf050cf

Browse files
feat: add update and align methods to dataframe (#57)
* feat: add update and align methods to dataframe
1 parent bc7be7f commit bf050cf

File tree

4 files changed

+271
-47
lines changed

4 files changed

+271
-47
lines changed

bigframes/core/block_transforms.py

+72
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,75 @@ def _kurt_from_moments_and_count(
504504
kurt_id, na_cond_id, ops.partial_arg3(ops.where_op, None)
505505
)
506506
return block, kurt_id
507+
508+
509+
def align(
510+
left_block: blocks.Block,
511+
right_block: blocks.Block,
512+
join: str = "outer",
513+
axis: typing.Union[str, int, None] = None,
514+
) -> typing.Tuple[blocks.Block, blocks.Block]:
515+
axis_n = core.utils.get_axis_number(axis) if axis is not None else None
516+
# Must align columns first as other way will likely create extra joins
517+
if (axis_n is None) or axis_n == 1:
518+
left_block, right_block = align_columns(left_block, right_block, join=join)
519+
if (axis_n is None) or axis_n == 0:
520+
left_block, right_block = align_rows(left_block, right_block, join=join)
521+
return left_block, right_block
522+
523+
524+
def align_rows(
525+
left_block: blocks.Block,
526+
right_block: blocks.Block,
527+
join: str = "outer",
528+
):
529+
joined_index, (get_column_left, get_column_right) = left_block.index.join(
530+
right_block.index, how=join
531+
)
532+
left_columns = [get_column_left(col) for col in left_block.value_columns]
533+
right_columns = [get_column_right(col) for col in right_block.value_columns]
534+
535+
left_block = joined_index._block.select_columns(left_columns)
536+
right_block = joined_index._block.select_columns(right_columns)
537+
return left_block, right_block
538+
539+
540+
def align_columns(
541+
left_block: blocks.Block,
542+
right_block: blocks.Block,
543+
join: str = "outer",
544+
):
545+
columns, lcol_indexer, rcol_indexer = left_block.column_labels.join(
546+
right_block.column_labels, how=join, return_indexers=True
547+
)
548+
column_indices = zip(
549+
lcol_indexer if (lcol_indexer is not None) else range(len(columns)),
550+
rcol_indexer if (rcol_indexer is not None) else range(len(columns)),
551+
)
552+
left_column_ids = []
553+
right_column_ids = []
554+
555+
original_left_block = left_block
556+
original_right_block = right_block
557+
558+
for left_index, right_index in column_indices:
559+
if left_index >= 0:
560+
left_col_id = original_left_block.value_columns[left_index]
561+
else:
562+
dtype = right_block.dtypes[right_index]
563+
left_block, left_col_id = left_block.create_constant(
564+
None, dtype=dtype, label=original_right_block.column_labels[right_index]
565+
)
566+
left_column_ids.append(left_col_id)
567+
568+
if right_index >= 0:
569+
right_col_id = original_right_block.value_columns[right_index]
570+
else:
571+
dtype = original_left_block.dtypes[left_index]
572+
right_block, right_col_id = right_block.create_constant(
573+
None, dtype=dtype, label=left_block.column_labels[left_index]
574+
)
575+
right_column_ids.append(right_col_id)
576+
left_final = left_block.select_columns(left_column_ids)
577+
right_final = right_block.select_columns(right_column_ids)
578+
return left_final, right_final

bigframes/dataframe.py

+66-42
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,55 @@ def rpow(
745745

746746
__rpow__ = rpow
747747

748+
def align(
749+
self,
750+
other: typing.Union[DataFrame, bigframes.series.Series],
751+
join: str = "outer",
752+
axis: typing.Union[str, int, None] = None,
753+
) -> typing.Tuple[
754+
typing.Union[DataFrame, bigframes.series.Series],
755+
typing.Union[DataFrame, bigframes.series.Series],
756+
]:
757+
axis_n = utils.get_axis_number(axis) if axis else None
758+
if axis_n == 1 and isinstance(other, bigframes.series.Series):
759+
raise NotImplementedError(
760+
f"align with series and axis=1 not supported. {constants.FEEDBACK_LINK}"
761+
)
762+
left_block, right_block = block_ops.align(
763+
self._block, other._block, join=join, axis=axis
764+
)
765+
return DataFrame(left_block), other.__class__(right_block)
766+
767+
def update(self, other, join: str = "left", overwrite=True, filter_func=None):
768+
other = other if isinstance(other, DataFrame) else DataFrame(other)
769+
if join != "left":
770+
raise ValueError("Only 'left' join supported for update")
771+
772+
if filter_func is not None: # Will always take other if possible
773+
774+
def update_func(
775+
left: bigframes.series.Series, right: bigframes.series.Series
776+
) -> bigframes.series.Series:
777+
return left.mask(right.notna() & filter_func(left), right)
778+
779+
elif overwrite:
780+
781+
def update_func(
782+
left: bigframes.series.Series, right: bigframes.series.Series
783+
) -> bigframes.series.Series:
784+
return left.mask(right.notna(), right)
785+
786+
else:
787+
788+
def update_func(
789+
left: bigframes.series.Series, right: bigframes.series.Series
790+
) -> bigframes.series.Series:
791+
return left.mask(left.isna(), right)
792+
793+
result = self.combine(other, update_func, how=join)
794+
795+
self._set_block(result._block)
796+
748797
def combine(
749798
self,
750799
other: DataFrame,
@@ -753,56 +802,31 @@ def combine(
753802
],
754803
fill_value=None,
755804
overwrite: bool = True,
805+
*,
806+
how: str = "outer",
756807
) -> DataFrame:
757-
# Join rows
758-
joined_index, (get_column_left, get_column_right) = self._block.index.join(
759-
other._block.index, how="outer"
760-
)
761-
columns, lcol_indexer, rcol_indexer = self.columns.join(
762-
other.columns, how="outer", return_indexers=True
763-
)
808+
l_aligned, r_aligned = block_ops.align(self._block, other._block, join=how)
764809

765-
column_indices = zip(
766-
lcol_indexer if (lcol_indexer is not None) else range(len(columns)),
767-
rcol_indexer if (lcol_indexer is not None) else range(len(columns)),
810+
other_missing_labels = self._block.column_labels.difference(
811+
other._block.column_labels
768812
)
769813

770-
block = joined_index._block
814+
l_frame = DataFrame(l_aligned)
815+
r_frame = DataFrame(r_aligned)
771816
results = []
772-
for left_index, right_index in column_indices:
773-
if left_index >= 0 and right_index >= 0: # -1 indices indicate missing
774-
left_col_id = get_column_left(self._block.value_columns[left_index])
775-
right_col_id = get_column_right(other._block.value_columns[right_index])
776-
left_series = bigframes.series.Series(block.select_column(left_col_id))
777-
right_series = bigframes.series.Series(
778-
block.select_column(right_col_id)
779-
)
817+
for (label, lseries), (_, rseries) in zip(l_frame.items(), r_frame.items()):
818+
if not ((label in other_missing_labels) and not overwrite):
780819
if fill_value is not None:
781-
left_series = left_series.fillna(fill_value)
782-
right_series = right_series.fillna(fill_value)
783-
results.append(func(left_series, right_series))
784-
elif left_index >= 0:
785-
# Does not exist in other
786-
if overwrite:
787-
dtype = self.dtypes[left_index]
788-
block, null_col_id = block.create_constant(None, dtype=dtype)
789-
result = bigframes.series.Series(block.select_column(null_col_id))
790-
results.append(result)
820+
result = func(
821+
lseries.fillna(fill_value), rseries.fillna(fill_value)
822+
)
791823
else:
792-
left_col_id = get_column_left(self._block.value_columns[left_index])
793-
result = bigframes.series.Series(block.select_column(left_col_id))
794-
if fill_value is not None:
795-
result = result.fillna(fill_value)
796-
results.append(result)
797-
elif right_index >= 0:
798-
right_col_id = get_column_right(other._block.value_columns[right_index])
799-
result = bigframes.series.Series(block.select_column(right_col_id))
800-
if fill_value is not None:
801-
result = result.fillna(fill_value)
802-
results.append(result)
824+
result = func(lseries, rseries)
803825
else:
804-
# Should not be possible
805-
raise ValueError("No right or left index.")
826+
result = (
827+
lseries.fillna(fill_value) if fill_value is not None else lseries
828+
)
829+
results.append(result)
806830

807831
if all([isinstance(val, bigframes.series.Series) for val in results]):
808832
import bigframes.core.reshape as rs

tests/system/small/test_dataframe.py

+71-5
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,77 @@ def test_combine(
12111211
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
12121212

12131213

1214+
@pytest.mark.parametrize(
1215+
("overwrite", "filter_func"),
1216+
[
1217+
(True, None),
1218+
(False, None),
1219+
(True, lambda x: x.isna() | (x % 2 == 0)),
1220+
],
1221+
ids=[
1222+
"default",
1223+
"overwritefalse",
1224+
"customfilter",
1225+
],
1226+
)
1227+
def test_df_update(overwrite, filter_func):
1228+
if pd.__version__.startswith("1."):
1229+
pytest.skip("dtype handled differently in pandas 1.x.")
1230+
index1 = pandas.Index([1, 2, 3, 4], dtype="Int64")
1231+
index2 = pandas.Index([1, 2, 4, 5], dtype="Int64")
1232+
pd_df1 = pandas.DataFrame(
1233+
{"a": [1, None, 3, 4], "b": [5, 6, None, 8]}, dtype="Int64", index=index1
1234+
)
1235+
pd_df2 = pandas.DataFrame(
1236+
{"a": [None, 20, 30, 40], "c": [90, None, 110, 120]},
1237+
dtype="Int64",
1238+
index=index2,
1239+
)
1240+
1241+
bf_df1 = dataframe.DataFrame(pd_df1)
1242+
bf_df2 = dataframe.DataFrame(pd_df2)
1243+
1244+
bf_df1.update(bf_df2, overwrite=overwrite, filter_func=filter_func)
1245+
pd_df1.update(pd_df2, overwrite=overwrite, filter_func=filter_func)
1246+
1247+
pd.testing.assert_frame_equal(bf_df1.to_pandas(), pd_df1)
1248+
1249+
1250+
@pytest.mark.parametrize(
1251+
("join", "axis"),
1252+
[
1253+
("outer", None),
1254+
("outer", 0),
1255+
("outer", 1),
1256+
("left", 0),
1257+
("right", 1),
1258+
("inner", None),
1259+
("inner", 1),
1260+
],
1261+
)
1262+
def test_df_align(join, axis):
1263+
index1 = pandas.Index([1, 2, 3, 4], dtype="Int64")
1264+
index2 = pandas.Index([1, 2, 4, 5], dtype="Int64")
1265+
pd_df1 = pandas.DataFrame(
1266+
{"a": [1, None, 3, 4], "b": [5, 6, None, 8]}, dtype="Int64", index=index1
1267+
)
1268+
pd_df2 = pandas.DataFrame(
1269+
{"a": [None, 20, 30, 40], "c": [90, None, 110, 120]},
1270+
dtype="Int64",
1271+
index=index2,
1272+
)
1273+
1274+
bf_df1 = dataframe.DataFrame(pd_df1)
1275+
bf_df2 = dataframe.DataFrame(pd_df2)
1276+
1277+
bf_result1, bf_result2 = bf_df1.align(bf_df2, join=join, axis=axis)
1278+
pd_result1, pd_result2 = pd_df1.align(pd_df2, join=join, axis=axis)
1279+
1280+
# Don't check dtype as pandas does unnecessary float conversion
1281+
pd.testing.assert_frame_equal(bf_result1.to_pandas(), pd_result1, check_dtype=False)
1282+
pd.testing.assert_frame_equal(bf_result2.to_pandas(), pd_result2, check_dtype=False)
1283+
1284+
12141285
def test_combine_first(
12151286
scalars_df_index,
12161287
scalars_df_2_index,
@@ -1232,11 +1303,6 @@ def test_combine_first(
12321303
pd_df_b.columns = ["b", "a", "d"]
12331304
pd_result = pd_df_a.combine_first(pd_df_b)
12341305

1235-
print("pandas")
1236-
print(pd_result.to_string())
1237-
print("bigframes")
1238-
print(bf_result.to_string())
1239-
12401306
# Some dtype inconsistency for all-NULL columns
12411307
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
12421308

third_party/bigframes_vendored/pandas/core/frame.py

+62
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,35 @@ def drop(
503503
"""
504504
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
505505

506+
def align(
507+
self,
508+
other,
509+
join="outer",
510+
axis=None,
511+
) -> tuple:
512+
"""
513+
Align two objects on their axes with the specified join method.
514+
515+
Join method is specified for each axis Index.
516+
517+
Args:
518+
other (DataFrame or Series):
519+
join ({{'outer', 'inner', 'left', 'right'}}, default 'outer'):
520+
Type of alignment to be performed.
521+
left: use only keys from left frame, preserve key order.
522+
right: use only keys from right frame, preserve key order.
523+
outer: use union of keys from both frames, sort keys lexicographically.
524+
inner: use intersection of keys from both frames,
525+
preserve the order of the left keys.
526+
527+
axis (allowed axis of the other object, default None):
528+
Align on index (0), columns (1), or both (None).
529+
530+
Returns:
531+
tuple of (DataFrame, type of other): Aligned objects.
532+
"""
533+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
534+
506535
def rename(
507536
self,
508537
*,
@@ -1265,6 +1294,39 @@ def combine_first(self, other) -> DataFrame:
12651294
"""
12661295
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
12671296

1297+
def update(
1298+
self, other, join: str = "left", overwrite: bool = True, filter_func=None
1299+
) -> DataFrame:
1300+
"""
1301+
Modify in place using non-NA values from another DataFrame.
1302+
1303+
Aligns on indices. There is no return value.
1304+
1305+
Args:
1306+
other (DataFrame, or object coercible into a DataFrame):
1307+
Should have at least one matching index/column label
1308+
with the original DataFrame. If a Series is passed,
1309+
its name attribute must be set, and that will be
1310+
used as the column name to align with the original DataFrame.
1311+
join ({'left'}, default 'left'):
1312+
Only left join is implemented, keeping the index and columns of the
1313+
original object.
1314+
overwrite (bool, default True):
1315+
How to handle non-NA values for overlapping keys:
1316+
True: overwrite original DataFrame's values
1317+
with values from `other`.
1318+
False: only update values that are NA in
1319+
the original DataFrame.
1320+
1321+
filter_func (callable(1d-array) -> bool 1d-array, optional):
1322+
Can choose to replace values other than NA. Return True for values
1323+
that should be updated.
1324+
1325+
Returns:
1326+
None: This method directly changes calling object.
1327+
"""
1328+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
1329+
12681330
# ----------------------------------------------------------------------
12691331
# Data reshaping
12701332

0 commit comments

Comments
 (0)