Skip to content

Commit

Permalink
feat: Add SQL support for the NORMALIZE string function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 14, 2025
1 parent 8dcaec2 commit c4f8e08
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
3 changes: 0 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ repository = { workspace = true }
description = "SQL transpiler for Polars. Converts SQL to Polars logical plans"

[dependencies]
arrow = { workspace = true }
polars-core = { workspace = true, features = ["rows"] }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "dtype-struct", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_normalize", "string_reverse", "strings", "timezones", "trigonometry"] }
Expand All @@ -19,10 +18,8 @@ polars-time = { workspace = true }
polars-utils = { workspace = true }

hex = { workspace = true }
once_cell = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlparser = { workspace = true }

[dev-dependencies]
Expand Down
35 changes: 34 additions & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use polars_core::prelude::{
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
use polars_ops::chunked_array::UnicodeForm;
use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when};
use polars_plan::plans::{typed_lit, LiteralValue};
use polars_plan::prelude::LiteralValue::Null;
Expand Down Expand Up @@ -376,6 +377,13 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT LTRIM(column_1) FROM df;
/// ```
LTrim,
/// SQL 'normalize' function
/// Convert string to Unicode normalization form
/// (one of "NFC", "NFKC", "NFD", or "NFKD").
/// ```sql
/// SELECT NORMALIZE(column_1, 'NFC') FROM df;
/// ```
Normalize,
/// SQL 'octet_length' function
/// Returns the length of a given string in bytes.
/// ```sql
Expand All @@ -391,7 +399,7 @@ pub(crate) enum PolarsSQLFunctions {
/// SQL 'replace' function
/// Replace a given substring with another string.
/// ```sql
/// SELECT REPLACE(column_1,'old','new') FROM df;
/// SELECT REPLACE(column_1, 'old', 'new') FROM df;
/// ```
Replace,
/// SQL 'reverse' function
Expand Down Expand Up @@ -859,6 +867,7 @@ impl PolarsSQLFunctions {
"left" => Self::Left,
"lower" => Self::Lower,
"ltrim" => Self::LTrim,
"normalize" => Self::Normalize,
"octet_length" => Self::OctetLength,
"strpos" => Self::StrPos,
"regexp_like" => Self::RegexpLike,
Expand Down Expand Up @@ -1152,6 +1161,30 @@ impl SQLFunctionVisitor<'_> {
},
}
},
Normalize => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
2 => {
self.try_visit_binary(|e, form| {
let form = match form {
Expr::Literal(LiteralValue::String(s)) => match s.as_str() {
"NFC" => UnicodeForm::NFC,
"NFD" => UnicodeForm::NFD,
"NFKC" => UnicodeForm::NFKC,
"NFKD" => UnicodeForm::NFKD,
_ => polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s),
},
_ => polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", form),
};
Ok(e.str().normalize(form))
})
},
_ => {
polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
},
}
},
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
StrPos => {
// // note: SQL is 1-indexed; returns zero if no match found
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,21 @@ def test_string_like_multiline() -> None:
assert df.sql(f"SELECT txt FROM self WHERE txt LIKE '{s}'").item() == s


@pytest.mark.parametrize("form", ["NFKC", "NFKD"])
def test_string_normalize(form: str) -> None:
df = pl.DataFrame({"txt": ["Test", "𝕋𝕖𝕤𝕥", "𝕿𝖊𝖘𝖙", "𝗧𝗲𝘀𝘁", "Ⓣⓔⓢⓣ"]}) # noqa: RUF001
res = df.sql(
f"""
SELECT txt, NORMALIZE(txt,'{form}') AS norm_txt
FROM self
"""
)
assert res.to_dict(as_series=False) == {
"txt": ["Test", "𝕋𝕖𝕤𝕥", "𝕿𝖊𝖘𝖙", "𝗧𝗲𝘀𝘁", "Ⓣⓔⓢⓣ"], # noqa: RUF001
"norm_txt": ["Test", "Test", "Test", "Test", "Test"],
}


def test_string_position() -> None:
df = pl.Series(
name="city",
Expand Down

0 comments on commit c4f8e08

Please sign in to comment.