Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Additional checks
Browse files Browse the repository at this point in the history
ag-ramachandran committed Jan 7, 2025
1 parent 56fd2fc commit c4b67d5
Showing 2 changed files with 49 additions and 57 deletions.
35 changes: 20 additions & 15 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,8 @@
"min": "min",
"max": "max",
}
AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\w+)\s*\)?\s*\)"
#AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\w+)\s*\)?\s*\)"
AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\[?\"?\'?\w+\"?\]?)\s*\)?\s*\)"


class UniversalSet:
@@ -37,6 +38,7 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect, **kw):
super().__init__(dialect, initial_quote='["', final_quote='"]', **kw)


class KustoKqlCompiler(compiler.SQLCompiler):
OPERATORS[operators.and_] = " and "

@@ -58,7 +60,6 @@ def visit_select(
**kwargs,
):
logger.debug("Incoming query: %s", select_stmt)

if len(select_stmt.get_final_froms()) != 1:
raise NotSupportedError('Only single "select from" query is supported in kql compiler')
compiled_query_lines = []
@@ -98,9 +99,7 @@ def visit_select(
compiled_query_lines.append(
f"| take {self.process(select_stmt._limit_clause, **kwargs)}"
) # pylint: disable=protected-access

compiled_query_lines = list(filter(None, compiled_query_lines))

compiled_query = "\n".join(compiled_query_lines)
logger.warning("Compiled query: %s", compiled_query)
return compiled_query
@@ -135,15 +134,10 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s
column_name, column_alias = self._extract_column_name_and_alias(column)
column_alias = self._escape_and_quote_columns(column_alias, True)
# Do we have a group by clause ?
match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE)
# Do we have aggregate columns ?
if match_agg_cols:
kql_agg = self._extract_maybe_agg_column_parts(column_name)
if kql_agg:
has_aggregates = True
aggregate_func, distinct_keyword, agg_column_name = match_agg_cols.groups()
# Check if the aggregate function is count_distinct. This is case from superset
# where we can use count(distinct or count_distinct)
is_distinct = bool(distinct_keyword) or aggregate_func.casefold() == "count_distinct"
kql_agg = self._sql_to_kql_aggregate(aggregate_func, agg_column_name, is_distinct)
summarize_columns.add(self._build_column_projection(kql_agg, column_alias))
# No group by clause
else:
@@ -152,7 +146,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s
if column_alias and column_alias != column_name:
extend_columns.add(self._build_column_projection(column_name, column_alias, True))
if column_alias:
projection_columns.append(self._escape_and_quote_columns(column_alias,True))
projection_columns.append(self._escape_and_quote_columns(column_alias, True))
else:
projection_columns.append(self._escape_and_quote_columns(column_name))
# group by columns
@@ -176,6 +170,18 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s
"sort": sort_statement,
}

@staticmethod
def _extract_maybe_agg_column_parts(column_name):
match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE)
if match_agg_cols and match_agg_cols.groups():
# Check if the aggregate function is count_distinct. This is case from superset
# where we can use count(distinct or count_distinct)
aggregate_func, distinct_keyword, agg_column_name = match_agg_cols.groups()
is_distinct = bool(distinct_keyword) or aggregate_func.casefold() == "count_distinct"
kql_agg = KustoKqlCompiler._sql_to_kql_aggregate(aggregate_func.lower(), agg_column_name, is_distinct)
return kql_agg
return None

def _get_order_by(self, order_by_cols):
unwrapped_order_by = []
for elem in order_by_cols:
@@ -192,14 +198,13 @@ def _get_order_by(self, order_by_cols):
f"{self._escape_and_quote_columns(sort_parts[0],is_alias=True)} {sort_parts[1].lower()}"
)
elif len(sort_parts) == 1:
unwrapped_order_by.append(self._escape_and_quote_columns(sort_parts[0],is_alias=True))
unwrapped_order_by.append(self._escape_and_quote_columns(sort_parts[0], is_alias=True))
else:
unwrapped_order_by.append(elem.text.replace(" ASC", " asc").replace(" DESC", " desc"))
else:
logger.warning("Unsupported order by clause: %s of type %s", elem, type(elem))
return unwrapped_order_by


def _group_by(self, group_by_cols):
by_columns = set()
for column in group_by_cols:
@@ -382,7 +387,7 @@ def _sql_to_kql_aggregate(sql_agg: str, column_name: str = None, is_distinct: bo
if return_value:
return return_value
# Other summarize operators have to be looked up
aggregate_function = aggregates_sql_to_kql.get(sql_agg.split("(")[0])
aggregate_function = aggregates_sql_to_kql.get(sql_agg.lower().split("(")[0])
if aggregate_function:
return_value = f"{aggregate_function}({column_name_escaped})"
return return_value
71 changes: 29 additions & 42 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
import sqlalchemy as sa
from sqlalchemy import (
@@ -9,14 +11,15 @@
column,
create_engine,
distinct,
func,
literal_column,
select,
text, func,
text,
)
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.selectable import TextAsFrom

from sqlalchemy_kusto.dialect_kql import KustoKqlCompiler
from sqlalchemy_kusto.dialect_kql import AGGREGATE_PATTERN, KustoKqlCompiler

engine = create_engine("kustokql+https://localhost/testdb")

@@ -216,6 +219,7 @@ def test_sql_to_kql_aggregate():
)
assert KustoKqlCompiler._sql_to_kql_aggregate("sum", "Sales") == 'sum(["Sales"])'
assert KustoKqlCompiler._sql_to_kql_aggregate("avg", "ResponseTime") == 'avg(["ResponseTime"])'
assert KustoKqlCompiler._sql_to_kql_aggregate("AVG", "ResponseTime") == 'avg(["ResponseTime"])'
assert KustoKqlCompiler._sql_to_kql_aggregate("min", "Size") == 'min(["Size"])'
assert KustoKqlCompiler._sql_to_kql_aggregate("max", "Area") == 'max(["Area"])'
assert KustoKqlCompiler._sql_to_kql_aggregate("unknown", "Column") is None
@@ -245,6 +249,7 @@ def test_limit():
query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5"
assert query_compiled == query_expected


def test_select_count():
kql_query = "logs"
column_count = literal_column("count(*)").label("count")
@@ -271,44 +276,6 @@ def test_select_count():

assert query_compiled == query_expected

def compiler_with_startofmonth_group_by():
metadata = MetaData()
sales_data = Table(
"SalesData",
metadata,
Column("order_date", String),
Column("product_line", String),
Column("sales", sqltypes.Float),
)
query = (
select(
[
func.startofmonth(sales_data.c.order_date).label("order_date"),
sales_data.c.product_line.label("product_line"),
func.sum(sales_data.c.sales).label("(Sales)"),
]
)
.where(
text(
"order_date >= datetime('2003-01-01T00:00:00.000000') and order_date < datetime('2005-06-01T00:00:00.000000')"
)
)
.group_by(
func.startofmonth(sales_data.c.order_date),
sales_data.c.product_line,
)
.order_by(text('"(Sales)" DESC'))
)

query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
query_expected = (
'["SalesData"]'
'| extend ["order_date"] = startofmonth(["order_date"])'
'| summarize ["(Sales)"] = sum(["sales"]) by ["order_date"], ["product_line"]'
'| project ["order_date"], ["product_line"], ["(Sales)"]'
'| order by ["(Sales)"] desc'
)
assert query_compiled == query_expected

def test_select_with_let():
kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y"
@@ -370,13 +337,33 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_
metadata,
)
query = stream.select().limit(5)

query_compiled = str(query.compile(engine)).replace("\n", "")

query_expected = f"{expected_table_name}| take __[POSTCOMPILE_param_1]"
assert query_compiled == query_expected


@pytest.mark.parametrize(
"column_name,expected_aggregate",
[
("AVG(Score)", 'avg(["Score"])'),
('AVG("2014")', 'avg(["2014"])'),
('sum("2014")', 'sum(["2014"])'),
("SUM(scores)", 'sum(["scores"])'),
('MIN("scores")', 'min(["scores"])'),
('MIN(["scores"])', 'min(["scores"])'),
('max(scores)', 'max(["scores"])'),
('startofmonth(somedate)', None),
('startofmonth(somedate)/time(1d)', None),
],
)
def test_match_aggregates(column_name: str, expected_aggregate: str):
kql_agg = KustoKqlCompiler._extract_maybe_agg_column_parts(column_name)
if expected_aggregate:
assert kql_agg is not None
assert kql_agg == expected_aggregate
else :
assert kql_agg is None

@pytest.mark.parametrize(
"query_table_name,expected_table_name",
[

0 comments on commit c4b67d5

Please sign in to comment.