Skip to content

Commit

Permalink
* Fix issues in order by column names
Browse files Browse the repository at this point in the history
* Fix failing tests
  • Loading branch information
ag-ramachandran committed Jan 7, 2025
1 parent ddb0d04 commit 56fd2fc
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 61 deletions.
100 changes: 60 additions & 40 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,6 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect, **kw):
super().__init__(dialect, initial_quote='["', final_quote='"]', **kw)


def _get_order_by(order_by_cols):
unwrapped_order_by = []
for elem in order_by_cols:
if isinstance(elem, sql.elements._label_reference):
nested_element = elem.element
unwrapped_order_by.append(
f"{nested_element._order_by_label_element.name} "
f"{'desc' if (nested_element.modifier == operators.desc_op) else 'asc'}"
)
elif isinstance(elem, sql.elements.TextClause):
unwrapped_order_by.append(elem.text.replace(" ASC", " asc").replace(" DESC", " desc"))
else:
logger.warning("Unsupported order by clause: %s of type %s", elem, type(elem))
return unwrapped_order_by


class KustoKqlCompiler(compiler.SQLCompiler):
OPERATORS[operators.and_] = " and "

Expand All @@ -78,7 +61,6 @@ def visit_select(

if len(select_stmt.get_final_froms()) != 1:
raise NotSupportedError('Only single "select from" query is supported in kql compiler')

compiled_query_lines = []

from_object = select_stmt.get_final_froms()[0]
Expand All @@ -97,15 +79,19 @@ def visit_select(
else:
compiled_query_lines.append(self._convert_schema_in_statement(from_object.text))

projections_parts_dict = self._get_projection_or_summarize(select_stmt)
if "extend" in projections_parts_dict:
compiled_query_lines.append(projections_parts_dict.pop("extend"))

if select_stmt._whereclause is not None:
where_clause = select_stmt._whereclause._compiler_dispatch(self, **kwargs)
if where_clause:
converted_where_clause = self.sql_to_kql_where(where_clause)
converted_where_clause = self._sql_to_kql_where(where_clause)
compiled_query_lines.append(f"| where {converted_where_clause}")

projections = self._get_projection_or_summarize(select_stmt)
if projections:
compiled_query_lines.append(projections)
for statement_part in projections_parts_dict.values():
if statement_part:
compiled_query_lines.append(statement_part)

if select_stmt._limit_clause is not None: # pylint: disable=protected-access
kwargs["literal_execute"] = True
Expand All @@ -122,7 +108,7 @@ def visit_select(
def limit_clause(self, select, **kw):
return ""

def _get_projection_or_summarize(self, select: selectable.Select) -> str:
def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]:
"""Builds the ending part of the query either project or summarize"""
columns = select.inner_columns
group_by_cols = select._group_by_clauses # pylint: disable=protected-access
Expand All @@ -147,6 +133,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str:
projection_columns = []
for column in [c for c in columns if c.name != "*"]:
column_name, column_alias = self._extract_column_name_and_alias(column)
column_alias = self._escape_and_quote_columns(column_alias, True)
# Do we have a group by clause ?
match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE)
# Do we have aggregate columns ?
Expand All @@ -165,26 +152,53 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str:
if column_alias and column_alias != column_name:
extend_columns.add(self._build_column_projection(column_name, column_alias, True))
if column_alias:
projection_columns.append(self._escape_and_quote_columns(column_alias))
projection_columns.append(self._escape_and_quote_columns(column_alias,True))
else:
projection_columns.append(self._escape_and_quote_columns(column_name))
# group by columns
by_columns = self._group_by(group_by_cols)
if has_aggregates or bool(by_columns): # Summarize can happen with or without aggregate being created
# escape each column with _escape_and_quote_columns
summarize_statement = f"| summarize {', '.join(summarize_columns)} "
if by_columns:
summarize_statement = f"{summarize_statement} by {', '.join(by_columns)}"
if extend_columns:
# escape each column with _escape_and_quote_columns
extend_statement = f"| extend {', '.join(sorted(extend_columns))}"
project_statement = f"| project {', '.join(projection_columns)}" if projection_columns else ""

unwrapped_order_by = _get_order_by(order_by_cols)
unwrapped_order_by = self._get_order_by(order_by_cols)

sort_statement = f"| order by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else ""
# projection_statement = f"{extend_statement}{summarize_statement}{project_statement}{sort_statement}"
return {
"extend": extend_statement,
"summarize": summarize_statement,
"project": project_statement,
"sort": sort_statement,
}

def _get_order_by(self, order_by_cols):
unwrapped_order_by = []
for elem in order_by_cols:
if isinstance(elem, sql.elements._label_reference):
nested_element = elem.element
unwrapped_order_by.append(
f"{self._escape_and_quote_columns(nested_element._order_by_label_element.name,is_alias=True)} "
f"{'desc' if (nested_element.modifier == operators.desc_op) else 'asc'}"
)
elif isinstance(elem, sql.elements.TextClause):
sort_parts = elem.text.split()
if len(sort_parts) == 2:
unwrapped_order_by.append(
f"{self._escape_and_quote_columns(sort_parts[0],is_alias=True)} {sort_parts[1].lower()}"
)
elif len(sort_parts) == 1:
unwrapped_order_by.append(self._escape_and_quote_columns(sort_parts[0],is_alias=True))
else:
unwrapped_order_by.append(elem.text.replace(" ASC", " asc").replace(" DESC", " desc"))
else:
logger.warning("Unsupported order by clause: %s of type %s", elem, type(elem))
return unwrapped_order_by

sort_statement = f"| sort by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else ""
projection_statement = f"{extend_statement}{summarize_statement}{project_statement}{sort_statement}"
return projection_statement

def _group_by(self, group_by_cols):
by_columns = set()
Expand All @@ -197,9 +211,11 @@ def _group_by(self, group_by_cols):
return by_columns

@staticmethod
def _escape_and_quote_columns(name: str):
def _escape_and_quote_columns(name: str, is_alias=False):
if name is None:
return None
name = name.strip()
if KustoKqlCompiler._is_kql_function(name):
if KustoKqlCompiler._is_kql_function(name) and not is_alias:
return name
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]
Expand All @@ -222,19 +238,23 @@ def _escape_and_quote_columns(name: str):
return f'["{name}"]'

@staticmethod
def sql_to_kql_where(where_clause: str):
where_clause = where_clause.strip()
#Handle 'IS NULL' and 'IS NOT NULL' -> KQL equivalents
def _sql_to_kql_where(where_clause: str):
where_clause = where_clause.strip().replace("\n", "")
# Handle 'IS NULL' and 'IS NOT NULL' -> KQL equivalents
where_clause = re.sub(
r'(\["[^\]]+"\])\s*IS NULL', r"isnull(\1)", where_clause, re.IGNORECASE
) # IS NULL -> isnull(["FieldName"])
where_clause = re.sub(
r'(\["[^\]]+"\])\s*IS NOT NULL', r"isnotnull(\1)", where_clause, re.IGNORECASE
) # IS NOT NULL -> isnotnull(["FieldName"])
#Handle comparison operators
where_clause = re.sub(
r"(\s)(=)\s*", r" \2= ", where_clause, re.IGNORECASE
) # Change '=' to '==' for equality comparisons
# Handle comparison operators
# Change '=' to '==' for equality comparisons
where_clause = re.sub(r"(?<=[^=])=(?=\s|$|[^=])", r"==", where_clause, re.IGNORECASE)
# Remove spaces in < = and > = operators
where_clause = re.sub(r"\s*<\s*=\s*", "<=", where_clause, re.IGNORECASE)
where_clause = re.sub(r"\s*>\s*=\s*", ">=", where_clause, re.IGNORECASE)
where_clause = where_clause.replace(">==", ">=")
where_clause = where_clause.replace("<==", "<=")
where_clause = re.sub(r"(\s)(<>|!=)\s*", r" \2 ", where_clause, re.IGNORECASE) # Handle '!=' and '<>'
where_clause = re.sub(
r"(\s)(<|<=|>|>=)\s*", r" \2 ", where_clause, re.IGNORECASE
Expand All @@ -249,7 +269,7 @@ def sql_to_kql_where(where_clause: str):
where_clause = re.sub(
r"(\s)IN\s*\(([^)]+)\)", r"\1in (\2)", where_clause, re.IGNORECASE
) # IN operator (list of values)
#Handle BETWEEN operator (if needed)
# Handle BETWEEN operator (if needed)

where_clause = re.sub(
r"(\w+|\[\"[A-Za-z0-9_]+\"\]) (BETWEEN|between) (\d) (AND|and) (\d)",
Expand Down
77 changes: 56 additions & 21 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
distinct,
literal_column,
select,
text,
text, func,
)
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.selectable import TextAsFrom

from sqlalchemy_kusto.dialect_kql import KustoKqlCompiler
Expand All @@ -38,7 +39,7 @@ def test_compiler_with_projection():
query_expected = (
'let virtual_table = (["logs"] '
"| take 10);virtual_table"
'| extend id = ["Id"], tId = ["TypeId"]'
'| extend ["id"] = ["Id"], ["tId"] = ["TypeId"]'
'| project ["id"], ["tId"], ["Type"]'
"| take __[POSTCOMPILE_param_1]"
)
Expand Down Expand Up @@ -108,11 +109,11 @@ def test_group_by_text():
query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]| extend ActiveUserMetric = ["ActiveUsers"], '
'EventInfo_Time = ["EventInfo_Time"] / time(1d)'
'["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], '
'["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| summarize by ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["ActiveUserMetric"]'
"| sort by ActiveUserMetric desc"
'| order by ["ActiveUserMetric"] desc'
)
assert query_compiled == query_expected

Expand All @@ -127,10 +128,10 @@ def test_group_by_text_vaccine_dataset():
)
query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
query_expected = (
'database("superset").["CovidVaccineData"]'
'| summarize by ["country_name"]'
'| project ["country_name"]'
"| sort by country_name asc"
'database("superset").["CovidVaccineData"]| '
'extend ["country_name"] = ["country_name"]| '
'summarize by ["country_name"]| '
'project ["country_name"]| order by ["country_name"] asc'
)
assert query_compiled == query_expected

Expand Down Expand Up @@ -163,10 +164,10 @@ def test_distinct_count_by_text():
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]'
'| extend EventInfo_Time = ["EventInfo_Time"] / time(1d)'
'| summarize DistinctUsers = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["DistinctUsers"]'
"| sort by ActiveUserMetric desc"
'| order by ["ActiveUserMetric"] desc'
)
assert query_compiled == query_expected

Expand All @@ -187,10 +188,10 @@ def test_distinct_count_alt_by_text():
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]'
'| extend EventInfo_Time = ["EventInfo_Time"] / time(1d)'
'| summarize DistinctUsers = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["DistinctUsers"]'
"| sort by ActiveUserMetric desc"
'| order by ["ActiveUserMetric"] desc'
)

assert query_compiled == query_expected
Expand Down Expand Up @@ -240,14 +241,10 @@ def test_limit():
sql = "logs"
limit = 5
query = select("*").select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")).limit(limit)

query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")

query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5"

assert query_compiled == query_expected


def test_select_count():
kql_query = "logs"
column_count = literal_column("count(*)").label("count")
Expand All @@ -266,14 +263,52 @@ def test_select_count():
'let inner_qry = (["logs"]);'
"inner_qry"
"| where Field1 > 1 and Field2 < 2"
"| summarize count = count() "
'| summarize ["count"] = count() '
'| project ["count"]'
"| sort by count desc"
'| order by ["count"] desc'
"| take 5"
)

assert query_compiled == query_expected

def compiler_with_startofmonth_group_by():
metadata = MetaData()
sales_data = Table(
"SalesData",
metadata,
Column("order_date", String),
Column("product_line", String),
Column("sales", sqltypes.Float),
)
query = (
select(
[
func.startofmonth(sales_data.c.order_date).label("order_date"),
sales_data.c.product_line.label("product_line"),
func.sum(sales_data.c.sales).label("(Sales)"),
]
)
.where(
text(
"order_date >= datetime('2003-01-01T00:00:00.000000') and order_date < datetime('2005-06-01T00:00:00.000000')"
)
)
.group_by(
func.startofmonth(sales_data.c.order_date),
sales_data.c.product_line,
)
.order_by(text('"(Sales)" DESC'))
)

query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "")
query_expected = (
'["SalesData"]'
'| extend ["order_date"] = startofmonth(["order_date"])'
'| summarize ["(Sales)"] = sum(["sales"]) by ["order_date"], ["product_line"]'
'| project ["order_date"], ["product_line"], ["(Sales)"]'
'| order by ["(Sales)"] desc'
)
assert query_compiled == query_expected

def test_select_with_let():
kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y"
Expand Down

0 comments on commit 56fd2fc

Please sign in to comment.