Skip to content

Commit

Permalink
* Additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ag-ramachandran committed Jan 9, 2025
1 parent c4b67d5 commit 747ca9f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
14 changes: 7 additions & 7 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

from sqlalchemy import Column, exc, sql
from sqlalchemy.sql import compiler, operators, selectable
Expand All @@ -22,7 +22,7 @@
"min": "min",
"max": "max",
}
#AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\w+)\s*\)?\s*\)"
# AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\w+)\s*\)?\s*\)"
AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\[?\"?\'?\w+\"?\]?)\s*\)?\s*\)"


Expand Down Expand Up @@ -107,7 +107,7 @@ def visit_select(
def limit_clause(self, select, **kw):
    """Return an empty string so this hook contributes no LIMIT text.

    Both ``select`` and ``kw`` are accepted but ignored.

    NOTE(review): presumably row limiting is rendered elsewhere by this
    compiler (e.g. in ``visit_select``) -- confirm against the full file.
    """
    return ""

def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]:
def _get_projection_or_summarize(self, select: selectable.Select) -> Dict[str, str]:
"""Builds the ending part of the query either project or summarize"""
columns = select.inner_columns
group_by_cols = select._group_by_clauses # pylint: disable=protected-access
Expand Down Expand Up @@ -171,7 +171,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s
}

@staticmethod
def _extract_maybe_agg_column_parts(column_name):
def _extract_maybe_agg_column_parts(column_name) -> Optional[str]:
match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE)
if match_agg_cols and match_agg_cols.groups():
# Check whether the aggregate function is count_distinct. This case comes from Superset.
Expand All @@ -189,7 +189,7 @@ def _get_order_by(self, order_by_cols):
nested_element = elem.element
unwrapped_order_by.append(
f"{self._escape_and_quote_columns(nested_element._order_by_label_element.name,is_alias=True)} "
f"{'desc' if (nested_element.modifier == operators.desc_op) else 'asc'}"
f"{'desc' if (nested_element.modifier is operators.desc_op) else 'asc'}"
)
elif isinstance(elem, sql.elements.TextClause):
sort_parts = elem.text.split()
Expand All @@ -216,7 +216,7 @@ def _group_by(self, group_by_cols):
return by_columns

@staticmethod
def _escape_and_quote_columns(name: str, is_alias=False):
def _escape_and_quote_columns(name: Optional[str], is_alias=False):
if name is None:
return None
name = name.strip()
Expand Down Expand Up @@ -367,7 +367,7 @@ def _convert_schema_in_statement(query: str) -> str:
return query.replace(original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1)

@staticmethod
def _sql_to_kql_aggregate(sql_agg: str, column_name: str = None, is_distinct: bool = False) -> str:
def _sql_to_kql_aggregate(sql_agg: str, column_name: str = None, is_distinct: bool = False) -> Optional[str]:
"""
Converts SQL aggregate function to KQL equivalent.
If a column name is provided, applies it to the aggregate.
Expand Down
43 changes: 22 additions & 21 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,23 +206,23 @@ def test_escape_and_quote_columns():
assert KustoKqlCompiler._escape_and_quote_columns("EventInfo_Time / time(1d)") == '["EventInfo_Time"] / time(1d)'


def test_sql_to_kql_aggregate():
    """Check the KQL text returned by ``_sql_to_kql_aggregate`` for several SQL aggregates."""
    # Wildcard count: called with no column name.
    assert KustoKqlCompiler._sql_to_kql_aggregate("count(*)") == "count()"
    assert KustoKqlCompiler._sql_to_kql_aggregate("count", "UserId") == 'count(["UserId"])'
    # Both distinct spellings are expected to produce dcount(...).
    assert (
        KustoKqlCompiler._sql_to_kql_aggregate("count(distinct", "CustomerId", is_distinct=True)
        == 'dcount(["CustomerId"])'
    )
    assert (
        KustoKqlCompiler._sql_to_kql_aggregate("count_distinct", "CustomerId", is_distinct=True)
        == 'dcount(["CustomerId"])'
    )
    assert KustoKqlCompiler._sql_to_kql_aggregate("sum", "Sales") == 'sum(["Sales"])'
    # Lower- and upper-case aggregate names are expected to give the same output.
    assert KustoKqlCompiler._sql_to_kql_aggregate("avg", "ResponseTime") == 'avg(["ResponseTime"])'
    assert KustoKqlCompiler._sql_to_kql_aggregate("AVG", "ResponseTime") == 'avg(["ResponseTime"])'
    assert KustoKqlCompiler._sql_to_kql_aggregate("min", "Size") == 'min(["Size"])'
    assert KustoKqlCompiler._sql_to_kql_aggregate("max", "Area") == 'max(["Area"])'
    # An unrecognised aggregate name is expected to yield None.
    assert KustoKqlCompiler._sql_to_kql_aggregate("unknown", "Column") is None
@pytest.mark.parametrize(
    "sql_aggregate, column_name, is_distinct, expected_kql",
    [
        # Wildcard count: no column name is supplied.
        ("count(*)", None, False, "count()"),
        ("count", "UserId", False, 'count(["UserId"])'),
        # Both distinct spellings are expected to produce dcount(...).
        ("count(distinct", "CustomerId", True, 'dcount(["CustomerId"])'),
        ("count_distinct", "CustomerId", True, 'dcount(["CustomerId"])'),
        ("sum", "Sales", False, 'sum(["Sales"])'),
        # Lower- and upper-case aggregate names are expected to give the same output.
        ("avg", "ResponseTime", False, 'avg(["ResponseTime"])'),
        ("AVG", "ResponseTime", False, 'avg(["ResponseTime"])'),
        ("min", "Size", False, 'min(["Size"])'),
        ("max", "Area", False, 'max(["Area"])'),
        # An unrecognised aggregate name is expected to yield None.
        ("unknown", "Column", False, None),
    ],
)
def test_sql_to_kql_aggregate(sql_aggregate, column_name, is_distinct, expected_kql):
    """Compare ``_sql_to_kql_aggregate`` output with the expected KQL for one case."""
    actual_kql = KustoKqlCompiler._sql_to_kql_aggregate(
        sql_aggregate,
        column_name=column_name,
        is_distinct=is_distinct,
    )
    assert actual_kql == expected_kql


def test_use_table():
Expand Down Expand Up @@ -351,19 +351,20 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_
("SUM(scores)", 'sum(["scores"])'),
('MIN("scores")', 'min(["scores"])'),
('MIN(["scores"])', 'min(["scores"])'),
('max(scores)', 'max(["scores"])'),
('startofmonth(somedate)', None),
('startofmonth(somedate)/time(1d)', None),
("max(scores)", 'max(["scores"])'),
("startofmonth(somedate)", None),
("startofmonth(somedate)/time(1d)", None),
],
)
def test_match_aggregates(column_name: str, expected_aggregate: str):
kql_agg = KustoKqlCompiler._extract_maybe_agg_column_parts(column_name)
if expected_aggregate:
assert kql_agg is not None
assert kql_agg == expected_aggregate
else :
else:
assert kql_agg is None


@pytest.mark.parametrize(
"query_table_name,expected_table_name",
[
Expand Down

0 comments on commit 747ca9f

Please sign in to comment.