Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Make the available concat alignment strategies more generic #20644

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"diagonal_relaxed",
"horizontal",
"align",
"align_full",
"align_inner",
"align_left",
"align_right",
]
CorrelationMethod: TypeAlias = Literal["pearson", "spearman"]
DbReadEngine: TypeAlias = Literal["adbc", "connectorx"]
Expand Down
102 changes: 70 additions & 32 deletions py-polars/polars/functions/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def concat(
----------
items
DataFrames, LazyFrames, or Series to concatenate.
how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'}
Series only support the `vertical` strategy.
how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align', 'align_full', 'align_inner', 'align_left', 'align_right'}
Note that `Series` only support the `vertical` strategy.

* vertical: Applies multiple `vstack` operations.
* vertical_relaxed: Same as `vertical`, but additionally coerces columns to
Expand All @@ -49,10 +49,14 @@ def concat(
their common supertype *if* they are mismatched (eg: Int32 → Int64).
* horizontal: Stacks Series from DataFrames horizontally and fills with `null`
if the lengths don't match.
* align: Combines frames horizontally, auto-determining the common key columns
and aligning rows using the same logic as `align_frames`; this behaviour is
patterned after a full outer join, but does not handle column-name collision.
(If you need more control, you should use a suitable join method instead).
* align, align_full, align_inner, align_left, align_right: Combines frames
horizontally, auto-determining the common key columns and aligning rows using
the same logic as `align_frames` (note that "align" is an alias for "align_full").
The suffix of the strategy name determines the type of join used to align the
frames, equivalent to the "how" parameter on `align_frames`. Note that the common
join columns are automatically coalesced, but other column collisions
will raise an error (if you need more control over this you should use
a suitable `join` method directly).
rechunk
Make sure that the result data is in contiguous memory.
parallel
Expand Down Expand Up @@ -100,6 +104,9 @@ def concat(
│ 2   ┆ 4   ┆ 6   ┆ 8   ┆ 10  │
└─────┴─────┴─────┴─────┴─────┘

The "diagonal" strategy allows for some frames to have missing columns,
the values for which are filled with `null`:

>>> df_d1 = pl.DataFrame({"a": [1], "b": [3]})
>>> df_d2 = pl.DataFrame({"a": [2], "c": [4]})
>>> pl.concat([df_d1, df_d2], how="diagonal")
Expand All @@ -113,10 +120,12 @@ def concat(
│ 2   ┆ null ┆ 4    │
└─────┴──────┴──────┘

The "align" strategies require at least one common column to align on:

>>> df_a1 = pl.DataFrame({"id": [1, 2], "x": [3, 4]})
>>> df_a2 = pl.DataFrame({"id": [2, 3], "y": [5, 6]})
>>> df_a3 = pl.DataFrame({"id": [1, 3], "z": [7, 8]})
>>> pl.concat([df_a1, df_a2, df_a3], how="align")
>>> pl.concat([df_a1, df_a2, df_a3], how="align") # equivalent to "align_full"
shape: (3, 4)
┌─────┬──────┬──────┬──────┐
│ id  ┆ x    ┆ y    ┆ z    │
Expand All @@ -127,6 +136,34 @@ def concat(
│ 2   ┆ 4    ┆ 5    ┆ null │
│ 3   ┆ null ┆ 6    ┆ 8    │
└─────┴──────┴──────┴──────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_left")
shape: (2, 4)
┌─────┬─────┬──────┬──────┐
│ id  ┆ x   ┆ y    ┆ z    │
│ --- ┆ --- ┆ ---  ┆ ---  │
│ i64 ┆ i64 ┆ i64  ┆ i64  │
╞═════╪═════╪══════╪══════╡
│ 1   ┆ 3   ┆ null ┆ 7    │
│ 2   ┆ 4   ┆ 5    ┆ null │
└─────┴─────┴──────┴──────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_right")
shape: (2, 4)
┌─────┬──────┬──────┬─────┐
│ id  ┆ x    ┆ y    ┆ z   │
│ --- ┆ ---  ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ i64  ┆ i64 │
╞═════╪══════╪══════╪═════╡
│ 1   ┆ null ┆ null ┆ 7   │
│ 3   ┆ null ┆ 6    ┆ 8   │
└─────┴──────┴──────┴─────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_inner")
shape: (0, 4)
┌─────┬─────┬─────┬─────┐
│ id  ┆ x   ┆ y   ┆ z   │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═════╡
└─────┴─────┴─────┴─────┘
""" # noqa: W505
# unpack/standardise (handles generator input)
elems = list(items)
Expand All @@ -139,47 +176,48 @@ def concat(
):
return elems[0]

if how == "align":
if how.startswith("align"):
if not isinstance(elems[0], (pl.DataFrame, pl.LazyFrame)):
msg = f"'align' strategy is not supported for {type(elems[0]).__name__!r}"
msg = f"{how!r} strategy is not supported for {type(elems[0]).__name__!r}"
raise TypeError(msg)

# establish common columns, maintaining the order in which they appear
all_columns = list(chain.from_iterable(e.collect_schema() for e in elems))
key = {v: k for k, v in enumerate(ordered_unique(all_columns))}
output_column_order = list(key)
common_cols = sorted(
reduce(
lambda x, y: set(x) & set(y), # type: ignore[arg-type, return-value]
chain(e.collect_schema() for e in elems),
),
key=lambda k: key.get(k, 0),
)
# we require at least one key column for 'align'
# we require at least one key column for 'align' strategies
if not common_cols:
msg = "'align' strategy requires at least one common column"
msg = f"{how!r} strategy requires at least one common column"
raise InvalidOperationError(msg)

# align the frame data using a full outer join with no suffix-resolution
# (so we raise an error in case of column collision, like "horizontal")
lf: LazyFrame = reduce(
lambda x, y: (
x.join(
y,
how="full",
on=common_cols,
suffix="_PL_CONCAT_RIGHT",
maintain_order="right_left",
)
# Coalesce full outer join columns
.with_columns(
F.coalesce([name, f"{name}_PL_CONCAT_RIGHT"])
for name in common_cols
)
.drop([f"{name}_PL_CONCAT_RIGHT" for name in common_cols])
),
[df.lazy() for df in elems],
).sort(by=common_cols)

# align frame data using a join, with no suffix-resolution (will raise
# a DuplicateError in case of column collision, same as "horizontal")
join_method: JoinStrategy = (
"full" if how == "align" else how.removeprefix("align_") # type: ignore[assignment]
)
lf: LazyFrame = (
reduce(
lambda x, y: (
x.join(
y,
on=common_cols,
how=join_method,
maintain_order="right_left",
coalesce=True,
)
),
[df.lazy() for df in elems],
)
.sort(by=common_cols)
.select(*output_column_order)
)
eager = isinstance(elems[0], pl.DataFrame)
return lf.collect() if eager else lf # type: ignore[return-value]

Expand Down
55 changes: 44 additions & 11 deletions py-polars/tests/unit/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,64 @@ def test_concat_align() -> None:
b = pl.DataFrame({"a": ["a", "b", "c"], "c": [5.5, 6.0, 7.5]})
c = pl.DataFrame({"a": ["a", "b", "c", "d", "e"], "d": ["w", "x", "y", "z", None]})

result = pl.concat([a, b, c], how="align")
for align_full in ("align", "align_full"):
result = pl.concat([a, b, c], how=align_full)
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e", "e"],
"b": [1, 2, None, 4, 5, 6],
"c": [5.5, 6.0, 7.5, None, None, None],
"d": ["w", "x", "y", "z", None, None],
}
)
assert_frame_equal(result, expected)

result = pl.concat([a, b, c], how="align_left")
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e", "e"],
"b": [1, 2, None, 4, 5, 6],
"c": [5.5, 6.0, 7.5, None, None, None],
"d": ["w", "x", "y", "z", None, None],
"a": ["a", "b", "d", "e", "e"],
"b": [1, 2, 4, 5, 6],
"c": [5.5, 6.0, None, None, None],
"d": ["w", "x", "z", None, None],
}
)
assert_frame_equal(result, expected)

result = pl.concat([a, b, c], how="align_right")
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e"],
"b": [1, 2, None, None, None],
"c": [5.5, 6.0, 7.5, None, None],
"d": ["w", "x", "y", "z", None],
}
)
assert_frame_equal(result, expected)

def test_concat_align_no_common_cols() -> None:
result = pl.concat([a, b, c], how="align_inner")
expected = pl.DataFrame(
{
"a": ["a", "b"],
"b": [1, 2],
"c": [5.5, 6.0],
"d": ["w", "x"],
}
)
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"strategy", ["align", "align_full", "align_left", "align_right"]
)
def test_concat_align_no_common_cols(strategy: ConcatMethod) -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [1, 2]})
df2 = pl.DataFrame({"c": [3, 4], "d": [3, 4]})

with pytest.raises(
InvalidOperationError,
match="'align' strategy requires at least one common column",
match=f"{strategy!r} strategy requires at least one common column",
):
pl.concat((df1, df2), how="align")


data2 = pl.DataFrame({"field3": [3, 4], "field4": ["C", "D"]})
pl.concat((df1, df2), how=strategy)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_invalid_concat_type_err() -> None:
)
with pytest.raises(
ValueError,
match="DataFrame `how` must be one of {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'}, got 'sausage'",
match="DataFrame `how` must be one of {'vertical', '.+', 'align_right'}, got 'sausage'",
):
pl.concat([df, df], how="sausage") # type: ignore[arg-type]

Expand Down
Loading