Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.list.index_of_in() architectural review PR #20733

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ sign = ["polars-plan/sign"]
timezones = ["polars-plan/timezones"]
list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"]
list_count = ["polars-ops/list_count", "polars-plan/list_count"]
list_index_of_in = ["polars-ops/list_index_of_in", "polars-plan/list_index_of_in"]
array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"]
true_div = ["polars-plan/true_div"]
extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"]
Expand Down Expand Up @@ -377,6 +378,7 @@ features = [
"list_drop_nulls",
"list_eval",
"list_gather",
"list_index_of_in",
"list_sample",
"list_sets",
"list_to_struct",
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,4 @@ abs = []
cov = []
gather = []
replace = ["is_in"]
list_index_of_in = []
54 changes: 54 additions & 0 deletions crates/polars-ops/src/chunked_array/list/index_of_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use super::*;
use crate::series::index_of;

pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult<Series> {
let mut builder = PrimitiveChunkedBuilder::<IdxType>::new(ca.name().clone(), ca.len());
if needles.len() == 1 {
// For some reason we need to do casting ourselves.
let needle = needles.get(0).unwrap();
let cast_needle = needle.cast(ca.dtype().inner_dtype().unwrap());
if cast_needle != needle {
todo!("nicer error handling");
}
let needle = Scalar::new(
cast_needle.dtype().clone(),
cast_needle.into_static(),
);
ca.amortized_iter().for_each(|opt_series| {
if let Some(subseries) = opt_series {
builder.append_option(
// TODO clone() sucks, maybe need to change the API for
// index_of so it takes AnyValue<'_> instead of a Scalar
// which implies AnyValue<'static>?
index_of(subseries.as_ref(), needle.clone())
.unwrap()
.map(|v| v.try_into().unwrap()),
);
} else {
builder.append_null();
}
});
} else {
ca.amortized_iter()
// TODO iter() assumes a single chunk. could continue to use this
// and just rechunk(), or have needles also be a ChunkedArray, in
// which case we'd need to have to use one of the
// dispatch-on-dtype-and-cast-to-relevant-chunkedarray-type macros
// to duplicate the implementation code per dtype.
.zip(needles.iter())
.for_each(|(opt_series, needle)| {
match (opt_series, needle) {
(None, _) => builder.append_null(),
(Some(subseries), needle) => {
let needle = Scalar::new(needles.dtype().clone(), needle.into_static());
builder.append_option(
index_of(subseries.as_ref(), needle)
.unwrap()
.map(|v| v.try_into().unwrap()),
);
},
}
});
}
Ok(builder.finish().into())
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ mod sets;
mod sum_mean;
#[cfg(feature = "list_to_struct")]
mod to_struct;
#[cfg(feature = "list_index_of_in")]
mod index_of_in;

#[cfg(feature = "list_count")]
pub use count::*;
Expand All @@ -23,6 +25,8 @@ pub use namespace::*;
pub use sets::*;
#[cfg(feature = "list_to_struct")]
pub use to_struct::*;
#[cfg(feature = "list_index_of_in")]
pub use index_of_in::*;

pub trait AsList {
fn as_list(&self) -> &ListChunked;
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ dtype-struct = ["polars-core/dtype-struct"]
object = ["polars-core/object"]
list_gather = ["polars-ops/list_gather"]
list_count = ["polars-ops/list_count"]
list_index_of_in = ["polars-ops/list_index_of_in"]
array_count = ["polars-ops/array_count", "dtype-array"]
trigonometry = []
sign = []
Expand Down Expand Up @@ -295,6 +296,7 @@ features = [
"streaming",
"true_div",
"sign",
"list_index_of_in",
]
# defines the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
16 changes: 16 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub enum ListFunction {
ToArray(usize),
#[cfg(feature = "list_to_struct")]
ToStruct(ListToStructArgs),
#[cfg(feature = "list_index_of_in")]
IndexOfIn,
}

impl ListFunction {
Expand Down Expand Up @@ -107,6 +109,8 @@ impl ListFunction {
NUnique => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
#[cfg(feature = "list_index_of_in")]
IndexOfIn => mapper.with_dtype(IDX_DTYPE),
}
}
}
Expand Down Expand Up @@ -180,6 +184,8 @@ impl Display for ListFunction {
ToArray(_) => "to_array",
#[cfg(feature = "list_to_struct")]
ToStruct(_) => "to_struct",
#[cfg(feature = "list_index_of_in")]
IndexOfIn => "index_of_in",
};
write!(f, "list.{name}")
}
Expand Down Expand Up @@ -243,6 +249,8 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
NUnique => map!(n_unique),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => map!(to_struct, &args),
#[cfg(feature = "list_index_of_in")]
IndexOfIn => map_as_slice!(index_of_in),
}
}
}
Expand Down Expand Up @@ -547,6 +555,14 @@ pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
list_count_matches(ca, element.get(0).unwrap()).map(Column::from)
}

/// Expression-level entry point for `list.index_of_in`.
///
/// `args[0]` is the list column to search in and `args[1]` holds the
/// needle(s); the actual work is delegated to `list_index_of_in`.
#[cfg(feature = "list_index_of_in")]
pub(super) fn index_of_in(args: &[Column]) -> PolarsResult<Column> {
    let (haystack, needles) = (&args[0], &args[1]);
    let lists = haystack.list()?;
    let positions = list_index_of_in(lists, needles.as_materialized_series());
    positions.map(Column::from)
}

pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
s.list()?.lst_sum().map(Column::from)
}
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,19 @@ impl ListNameSpace {
)
}

/// Find the index of `needle` in each list.
///
/// `needle` is converted into an [`Expr`] and passed as the second input of
/// the `ListFunction::IndexOfIn` function expression.
#[cfg(feature = "list_index_of_in")]
pub fn index_of_in<N: Into<Expr>>(self, needle: N) -> Expr {
    let needle_expr: Expr = needle.into();
    let function = FunctionExpr::ListExpr(ListFunction::IndexOfIn);

    self.0
        .map_many_private(function, &[needle_expr], false, None)
}

#[cfg(feature = "list_sets")]
fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr {
Expr::Function {
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ new_streaming = ["polars-lazy/new_streaming"]
bitwise = ["polars/bitwise"]
approx_unique = ["polars/approx_unique"]
string_normalize = ["polars/string_normalize"]
list_index_of_in = ["polars/list_index_of_in"]

dtype-i8 = []
dtype-i16 = []
Expand Down Expand Up @@ -207,6 +208,7 @@ operations = [
"list_any_all",
"list_drop_nulls",
"list_sample",
"list_index_of_in",
"cutqcut",
"rle",
"extract_groups",
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-python/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ impl PyExpr {
.into()
}

// Python binding for `Expr.list.index_of_in`: forwards `value` as the needle
// expression to the list namespace of this expression.
#[cfg(feature = "list_index_of_in")]
fn list_index_of_in(&self, value: PyExpr) -> Self {
    let lists = self.inner.clone().list();
    let needle = value.inner;
    lists.index_of_in(needle).into()
}

fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
self.inner
.clone()
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ trigonometry = ["polars-lazy?/trigonometry"]
true_div = ["polars-lazy?/true_div"]
unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"]
zip_with = ["polars-core/zip_with"]
list_index_of_in = ["polars-ops/list_index_of_in", "polars-lazy?/list_index_of_in"]

bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"]
polars_cloud = ["polars-lazy?/polars_cloud"]
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ list_any_all = ["polars-python/list_any_all"]
array_any_all = ["polars-python/array_any_all"]
list_drop_nulls = ["polars-python/list_drop_nulls"]
list_sample = ["polars-python/list_sample"]
list_index_of_in = ["polars-python/list_index_of_in"]
cutqcut = ["polars-python/cutqcut"]
rle = ["polars-python/rle"]
extract_groups = ["polars-python/extract_groups"]
Expand Down
16 changes: 16 additions & 0 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,22 @@ def count_matches(self, element: IntoExpr) -> Expr:
element = parse_into_expression(element, str_as_lit=True)
return wrap_expr(self._pyexpr.list_count_matches(element))

def index_of_in(self, element: IntoExpr) -> Expr:
    """
    Get the index of a value in each sub-list.

    Parameters
    ----------
    element
        The value to look for. Either a single value, which is searched for
        in every sub-list, or an expression producing one value per row.
        Strings are parsed as literal values, not as column names.

    Returns
    -------
    Expr
        Expression of the index data type (see :func:`polars.get_index_type`).
        A row is null when the value is not found in that sub-list.

    Examples
    --------
    Search for the same value in every sub-list:

    >>> df = pl.DataFrame({"a": [[3, 1], [2, 4], [5, 3, 1]], "b": [1, 2, 6]})
    >>> df.select(pl.col("a").list.index_of_in(1))
    shape: (3, 1)
    ┌──────┐
    │ a    │
    │ ---  │
    │ u32  │
    ╞══════╡
    │ 1    │
    │ null │
    │ 2    │
    └──────┘

    Search for a different value per row:

    >>> df.select(pl.col("a").list.index_of_in(pl.col("b")))
    shape: (3, 1)
    ┌──────┐
    │ a    │
    │ ---  │
    │ u32  │
    ╞══════╡
    │ 1    │
    │ 0    │
    │ null │
    └──────┘
    """
    needle_pyexpr = parse_into_expression(element, str_as_lit=True)
    return wrap_expr(self._pyexpr.list_index_of_in(needle_pyexpr))

def to_array(self, width: int) -> Expr:
"""
Convert a List column into an Array column with the same inner data type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Tests for ``.list.index_of_in()``."""

import polars as pl
from polars.testing import assert_frame_equal


def test_index_of_in_from_constant() -> None:
    """A single literal needle is looked up in every sub-list."""
    frame = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]})
    # 1 sits at index 1 in the first list, is absent from the second, and
    # sits at index 2 in the third.
    result = frame.select(pl.col("lists").list.index_of_in(1))
    expected = pl.DataFrame(
        {"lists": [1, None, 2]},
        schema={"lists": pl.get_index_type()},
    )
    assert_frame_equal(result, expected)


def test_index_of_in_from_column() -> None:
    """Needles taken from another column are matched row by row."""
    frame = pl.DataFrame(
        {
            "lists": [[3, 1], [2, 4], [5, 3, 1]],
            "values": [1, 2, 6],
        }
    )
    # Row-wise: 1 is at index 1, 2 is at index 0, and 6 is not present.
    result = frame.select(pl.col("lists").list.index_of_in(pl.col("values")))
    expected = pl.DataFrame(
        {"lists": [1, 0, None]},
        schema={"lists": pl.get_index_type()},
    )
    assert_frame_equal(result, expected)
Loading