Skip to content

Commit 63025b5

Browse files
feat: add multi-column dataframe merge
1 parent d8910d4 commit 63025b5

File tree

7 files changed

+206
-146
lines changed

7 files changed

+206
-146
lines changed

bigframes/core/blocks.py

+73
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import bigframes.core as core
3939
import bigframes.core.guid as guid
4040
import bigframes.core.indexes as indexes
41+
import bigframes.core.joins as joins
4142
import bigframes.core.ordering as ordering
4243
import bigframes.core.utils
4344
import bigframes.core.utils as utils
@@ -1436,6 +1437,78 @@ def concat(
14361437
result_block = result_block.reset_index()
14371438
return result_block
14381439

1440+
def merge(
    self,
    other: Block,
    how: typing.Literal[
        "inner",
        "left",
        "outer",
        "right",
    ],
    left_col_ids: typing.Sequence[str],
    right_col_ids: typing.Sequence[str],
    sort: bool,
    suffixes: tuple[str, str] = ("_x", "_y"),
) -> Block:
    """Join this block with ``other`` on one or more pairs of key columns.

    Key pairs whose labels have the same string form are merged into a
    single output column (taken from the coalesced join columns); for all
    other pairs both the left and the right key column are kept.

    Args:
        other: Block providing the right-hand side of the join.
        how: Join type, forwarded to ``joins.join_by_column``.
        left_col_ids: Column ids (not labels) of the join keys in this block.
        right_col_ids: Column ids (not labels) of the join keys in ``other``,
            paired positionally with ``left_col_ids``.
        sort: Forwarded to ``joins.join_by_column``.
        suffixes: Forwarded to ``utils.merge_column_labels`` to build the
            labels of non-key columns present on both sides.

    Returns:
        A new Block holding the selected left and right value columns,
        indexed by a freshly promoted offsets column.

    Raises:
        ValueError: If ``left_col_ids`` and ``right_col_ids`` differ in length.
    """
    # zip() below would silently drop the unmatched keys, so fail loudly first.
    if len(left_col_ids) != len(right_col_ids):
        raise ValueError(
            "left_col_ids and right_col_ids must have the same length, got "
            f"{len(left_col_ids)} and {len(right_col_ids)}."
        )

    (
        joined_expr,
        coalesced_join_cols,
        (get_column_left, get_column_right),
    ) = joins.join_by_column(
        self.expr,
        left_col_ids,
        other.expr,
        right_col_ids,
        how=how,
        sort=sort,
    )

    # which join key parts should be coalesced
    # NOTE(review): labels are compared by their string form, so e.g. 1 and
    # "1" are treated as the same key label - confirm this is intended.
    merge_join_key_mask = [
        str(self.col_id_to_label[left_id]) == str(other.col_id_to_label[right_id])
        for left_id, right_id in zip(left_col_ids, right_col_ids)
    ]
    labels_to_coalesce = [
        self.col_id_to_label[col_id]
        for col_id, coalesce in zip(left_col_ids, merge_join_key_mask)
        if coalesce
    ]

    def left_col_mapping(col_id: str) -> str:
        # A coalesced key is represented by the combined join column.
        if col_id in left_col_ids:
            join_key_part = left_col_ids.index(col_id)
            if merge_join_key_mask[join_key_part]:
                return coalesced_join_cols[join_key_part]
        return get_column_left(col_id)

    def right_col_mapping(col_id: str) -> typing.Optional[str]:
        # A coalesced key was already emitted from the left side: drop it.
        if col_id in right_col_ids:
            join_key_part = right_col_ids.index(col_id)
            if merge_join_key_mask[join_key_part]:
                return None
        return get_column_right(col_id)

    left_columns = [left_col_mapping(col_id) for col_id in self.value_columns]

    # Map each right column exactly once and keep only the surviving ones.
    right_columns = [
        mapped_id
        for mapped_id in map(right_col_mapping, other.value_columns)
        if mapped_id is not None
    ]

    expr = joined_expr.select_columns([*left_columns, *right_columns])
    labels = utils.merge_column_labels(
        self.column_labels,
        other.column_labels,
        coalesce_labels=labels_to_coalesce,
        suffixes=suffixes,
    )

    # Constructs default index
    expr, offset_index_id = expr.promote_offsets()
    return Block(expr, index_columns=[offset_index_id], column_labels=labels)
1511+
14391512
def _force_reproject(self) -> Block:
14401513
"""Forces a reprojection of the underlying tables expression. Used to force predicate/order application before subsequent operations."""
14411514
return Block(

bigframes/core/joins/single_column.py

+20-40
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def join_by_column(
4444
"right",
4545
],
4646
sort: bool = False,
47-
coalesce_join_keys: bool = True,
4847
allow_row_identity_join: bool = True,
4948
) -> Tuple[
5049
core.ArrayValue,
@@ -59,8 +58,6 @@ def join_by_column(
5958
right: Expression for right table to join.
6059
right_column_ids: Column IDs (not label) to join by.
6160
how: The type of join to perform.
62-
coalesce_join_keys: if set to False, returned column ids will contain
63-
both left and right join key columns.
6461
allow_row_identity_join (bool):
6562
If True, allow matching by row identity. Set to False to always
6663
perform a true JOIN in generated SQL.
@@ -71,8 +68,6 @@ def join_by_column(
7168
* Sequence[str]: Column IDs of the coalesced join columns. Sometimes either the
7269
left/right table will have missing rows. This column pulls the
7370
non-NULL value from either left/right.
74-
If coalesce_join_keys is False, will return uncombined left and
75-
right key columns.
7671
* Tuple[Callable, Callable]: For a given column ID from left or right,
7772
respectively, return the new column id from the combined expression.
7873
"""
@@ -100,9 +95,7 @@ def join_by_column(
10095
right_join_keys = [
10196
combined_expr.get_column(get_column_right(col)) for col in right_column_ids
10297
]
103-
join_key_cols = get_join_cols(
104-
left_join_keys, right_join_keys, how, coalesce_join_keys
105-
)
98+
join_key_cols = get_coalesced_join_cols(left_join_keys, right_join_keys, how)
10699
join_key_ids = [col.get_name() for col in join_key_cols]
107100
combined_expr = combined_expr.projection(
108101
[*join_key_cols, *combined_expr.columns]
@@ -182,9 +175,7 @@ def get_column_right(col_id):
182175
right_join_keys = [
183176
combined_table[get_column_right(col)] for col in right_column_ids
184177
]
185-
join_key_cols = get_join_cols(
186-
left_join_keys, right_join_keys, how, coalesce_join_keys
187-
)
178+
join_key_cols = get_coalesced_join_cols(left_join_keys, right_join_keys, how)
188179
# We could filter out the original join columns, but predicates/ordering
189180
# might still reference them in implicit joins.
190181
columns = (
@@ -226,46 +217,35 @@ def get_column_right(col_id):
226217
)
227218

228219

229-
def get_join_cols(
220+
def get_coalesced_join_cols(
    left_join_cols: typing.Iterable[ibis_types.Value],
    right_join_cols: typing.Iterable[ibis_types.Value],
    how: str,
) -> typing.List[ibis_types.Value]:
    """Build one output key column for each (left, right) join-key pair.

    Args:
        left_join_cols: Join key columns from the left side.
        right_join_cols: Join key columns from the right side, paired
            positionally with ``left_join_cols``.
        how: One of "inner", "left", "right" or "outer".

    Returns:
        A list with one column per key pair, each given a freshly generated
        "index_"-prefixed name: the left key for inner/left joins, the right
        key for right joins, and for outer joins the coalesced left/right
        value (or just the left key when both sides are the same column).

    Raises:
        ValueError: If ``how`` is not a supported join type.
    """
    # Validate once, up front, so an unsupported join type is reported even
    # when there are no key columns to iterate over.
    if how not in ("inner", "left", "right", "outer"):
        raise ValueError(f"Unexpected join type: {how}. {constants.FEEDBACK_LINK}")

    join_key_cols: list[ibis_types.Value] = []
    for left_col, right_col in zip(left_join_cols, right_join_cols):
        if how == "right":
            key_col = right_col
        elif how == "outer" and not left_col.name("index").equals(
            right_col.name("index")
        ):
            # The left index and the right index might contain null values, for
            # example due to an outer join with different numbers of rows.
            # Coalesce these to take the index value from either column.
            key_col = ibis.coalesce(left_col, right_col)
        else:
            # inner/left joins, or an outer join where both keys are exactly
            # the same column and so need no coalescing.
            key_col = left_col
        # Use a random name in case the left index and the right index have the
        # same name. In such a case, _x and _y suffixes will already be used.
        join_key_cols.append(key_col.name(guid.generate_guid(prefix="index_")))
    return join_key_cols
270250

271251

bigframes/core/utils.py

+33
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,36 @@ def get_standardized_ids(
8484
idx_ids, col_ids = ids[: len(idx_ids)], ids[len(idx_ids) :]
8585

8686
return col_ids, idx_ids
87+
88+
89+
def merge_column_labels(
    left_labels: pd.Index,
    right_labels: pd.Index,
    coalesce_labels: typing.Sequence,
    suffixes: tuple[typing.Optional[str], typing.Optional[str]] = ("_x", "_y"),
) -> pd.Index:
    """Compute the column labels of a merge result.

    Labels are emitted in order: all left labels, then the right labels that
    survive the merge.

    Args:
        left_labels: Column labels of the left side.
        right_labels: Column labels of the right side.
        coalesce_labels: Labels of join keys that are merged into a single
            output column; such a label appears once (from the left side) and
            is never suffixed.
        suffixes: Suffixes appended to the string form of a label that occurs
            on both sides and is not coalesced. A ``None`` entry leaves the
            label on that side unchanged, mirroring pandas.

    Returns:
        A ``pd.Index`` with the merged labels.
    """
    left_suffix, right_suffix = suffixes

    def with_suffix(label, suffix):
        # ``None`` means "do not rename"; otherwise the label becomes a string.
        return label if suffix is None else str(label) + suffix

    result_labels = []

    for col_label in left_labels:
        if col_label in right_labels and col_label not in coalesce_labels:
            result_labels.append(with_suffix(col_label, left_suffix))
        else:
            # Either unique to the left side, or a coalesced key whose single
            # output column keeps the plain label.
            result_labels.append(col_label)

    for col_label in right_labels:
        if col_label not in left_labels:
            result_labels.append(col_label)
        elif col_label not in coalesce_labels:
            result_labels.append(with_suffix(col_label, right_suffix))
        # Otherwise: coalesced key, already emitted from the left side.

    return pd.Index(result_labels)

bigframes/dataframe.py

+34-93
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import bigframes.core.indexers as indexers
4747
import bigframes.core.indexes as indexes
4848
import bigframes.core.io
49-
import bigframes.core.joins as joins
5049
import bigframes.core.ordering as order
5150
import bigframes.core.utils as utils
5251
import bigframes.core.window
@@ -1758,12 +1757,10 @@ def merge(
17581757
] = "inner",
17591758
# TODO(garrettwu): Currently can take inner, outer, left and right. To support
17601759
# cross joins
1761-
# TODO(garrettwu): Support "on" list of columns and None. Currently a single
1762-
# column must be provided
1763-
on: Optional[str] = None,
1760+
on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
17641761
*,
1765-
left_on: Optional[str] = None,
1766-
right_on: Optional[str] = None,
1762+
left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
1763+
right_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
17671764
sort: bool = False,
17681765
suffixes: tuple[str, str] = ("_x", "_y"),
17691766
) -> DataFrame:
@@ -1777,97 +1774,41 @@ def merge(
17771774
)
17781775
left_on, right_on = on, on
17791776

1780-
left = self
1781-
left_on_sql = self._sql_names(left_on)
1782-
# 0 elements already throws an exception
1783-
if len(left_on_sql) > 1:
1784-
raise ValueError(f"The column label {left_on} is not unique.")
1785-
left_on_sql = left_on_sql[0]
1786-
1787-
right_on_sql = right._sql_names(right_on)
1788-
if len(right_on_sql) > 1:
1789-
raise ValueError(f"The column label {right_on} is not unique.")
1790-
right_on_sql = right_on_sql[0]
1791-
1792-
(
1793-
joined_expr,
1794-
join_key_ids,
1795-
(get_column_left, get_column_right),
1796-
) = joins.join_by_column(
1797-
left._block.expr,
1798-
[left_on_sql],
1799-
right._block.expr,
1800-
[right_on_sql],
1801-
how=how,
1802-
sort=sort,
1803-
# In merging on the same column, it only returns 1 key column from coalesced both.
1804-
# While if 2 different columns, both will be presented in the result.
1805-
coalesce_join_keys=(left_on == right_on),
1806-
)
1807-
# TODO(swast): Add suffixes to the column labels instead of reusing the
1808-
# column IDs as the new labels.
1809-
# Drop the index column(s) to be consistent with pandas.
1810-
left_columns = [
1811-
join_key_ids[0] if (col_id == left_on_sql) else get_column_left(col_id)
1812-
for col_id in left._block.value_columns
1813-
]
1814-
1815-
right_columns = []
1816-
for col_id in right._block.value_columns:
1817-
if col_id == right_on_sql:
1818-
# When left_on == right_on
1819-
if len(join_key_ids) > 1:
1820-
right_columns.append(join_key_ids[1])
1821-
else:
1822-
right_columns.append(get_column_right(col_id))
1823-
1824-
expr = joined_expr.select_columns([*left_columns, *right_columns])
1825-
labels = self._get_merged_col_labels(
1826-
right, left_on=left_on, right_on=right_on, suffixes=suffixes
1827-
)
1777+
if utils.is_list_like(left_on):
1778+
left_on = list(left_on) # type: ignore
1779+
else:
1780+
left_on = [left_on]
18281781

1829-
# Constructs default index
1830-
expr, offset_index_id = expr.promote_offsets()
1831-
block = blocks.Block(
1832-
expr, index_columns=[offset_index_id], column_labels=labels
1782+
if utils.is_list_like(right_on):
1783+
right_on = list(right_on) # type: ignore
1784+
else:
1785+
right_on = [right_on]
1786+
1787+
left_join_ids = []
1788+
for label in left_on: # type: ignore
1789+
left_col_id = self._resolve_label_exact(label)
1790+
# 0 elements already throws an exception
1791+
if not left_col_id:
1792+
raise ValueError(f"No column {label} found in self.")
1793+
left_join_ids.append(left_col_id)
1794+
1795+
right_join_ids = []
1796+
for label in right_on: # type: ignore
1797+
right_col_id = right._resolve_label_exact(label)
1798+
if not right_col_id:
1799+
raise ValueError(f"No column {label} found in other.")
1800+
right_join_ids.append(right_col_id)
1801+
1802+
block = self._block.merge(
1803+
right._block,
1804+
how,
1805+
left_join_ids,
1806+
right_join_ids,
1807+
sort=sort,
1808+
suffixes=suffixes,
18331809
)
18341810
return DataFrame(block)
18351811

1836-
def _get_merged_col_labels(
1837-
self,
1838-
right: DataFrame,
1839-
left_on: str,
1840-
right_on: str,
1841-
suffixes: tuple[str, str] = ("_x", "_y"),
1842-
) -> List[blocks.Label]:
1843-
on_col_equal = left_on == right_on
1844-
1845-
left_col_labels: list[blocks.Label] = []
1846-
for col_label in self._block.column_labels:
1847-
if col_label in right._block.column_labels:
1848-
if on_col_equal and col_label == left_on:
1849-
# Merging on the same column only returns 1 key column from coalesce both.
1850-
# Take the left key column.
1851-
left_col_labels.append(col_label)
1852-
else:
1853-
left_col_labels.append(str(col_label) + suffixes[0])
1854-
else:
1855-
left_col_labels.append(col_label)
1856-
1857-
right_col_labels: list[blocks.Label] = []
1858-
for col_label in right._block.column_labels:
1859-
if col_label in self._block.column_labels:
1860-
if on_col_equal and col_label == left_on:
1861-
# Merging on the same column only returns 1 key column from coalesce both.
1862-
# Pass the right key column.
1863-
pass
1864-
else:
1865-
right_col_labels.append(str(col_label) + suffixes[1])
1866-
else:
1867-
right_col_labels.append(col_label)
1868-
1869-
return left_col_labels + right_col_labels
1870-
18711812
def join(
18721813
self, other: DataFrame, *, on: Optional[str] = None, how: str = "left"
18731814
) -> DataFrame:

0 commit comments

Comments
 (0)