From 56fd2fc95db03d4d6dc34d10473325fc34f44108 Mon Sep 17 00:00:00 2001 From: ag-ramachandran Date: Tue, 7 Jan 2025 07:26:48 +0530 Subject: [PATCH] * Fix issues in order by column names * Fix failing tests --- sqlalchemy_kusto/dialect_kql.py | 100 +++++++++++++++++++------------- tests/unit/test_dialect_kql.py | 77 +++++++++++++++++------- 2 files changed, 116 insertions(+), 61 deletions(-) diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index 7e77377..ab901f2 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -37,23 +37,6 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect, **kw): super().__init__(dialect, initial_quote='["', final_quote='"]', **kw) - -def _get_order_by(order_by_cols): - unwrapped_order_by = [] - for elem in order_by_cols: - if isinstance(elem, sql.elements._label_reference): - nested_element = elem.element - unwrapped_order_by.append( - f"{nested_element._order_by_label_element.name} " - f"{'desc' if (nested_element.modifier == operators.desc_op) else 'asc'}" - ) - elif isinstance(elem, sql.elements.TextClause): - unwrapped_order_by.append(elem.text.replace(" ASC", " asc").replace(" DESC", " desc")) - else: - logger.warning("Unsupported order by clause: %s of type %s", elem, type(elem)) - return unwrapped_order_by - - class KustoKqlCompiler(compiler.SQLCompiler): OPERATORS[operators.and_] = " and " @@ -78,7 +61,6 @@ def visit_select( if len(select_stmt.get_final_froms()) != 1: raise NotSupportedError('Only single "select from" query is supported in kql compiler') - compiled_query_lines = [] from_object = select_stmt.get_final_froms()[0] @@ -97,15 +79,19 @@ def visit_select( else: compiled_query_lines.append(self._convert_schema_in_statement(from_object.text)) + projections_parts_dict = self._get_projection_or_summarize(select_stmt) + if "extend" in projections_parts_dict: + compiled_query_lines.append(projections_parts_dict.pop("extend")) + if select_stmt._whereclause is not None: where_clause = select_stmt._whereclause._compiler_dispatch(self, **kwargs) if where_clause: - converted_where_clause = self.sql_to_kql_where(where_clause) + converted_where_clause = self._sql_to_kql_where(where_clause) compiled_query_lines.append(f"| where {converted_where_clause}") - projections = self._get_projection_or_summarize(select_stmt) - if projections: - compiled_query_lines.append(projections) + for statement_part in projections_parts_dict.values(): + if statement_part: + compiled_query_lines.append(statement_part) if select_stmt._limit_clause is not None: # pylint: disable=protected-access kwargs["literal_execute"] = True @@ -122,7 +108,7 @@ def visit_select( def limit_clause(self, select, **kw): return "" - def _get_projection_or_summarize(self, select: selectable.Select) -> str: + def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]: """Builds the ending part of the query either project or summarize""" columns = select.inner_columns group_by_cols = select._group_by_clauses # pylint: disable=protected-access @@ -147,6 +133,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str: projection_columns = [] for column in [c for c in columns if c.name != "*"]: column_name, column_alias = self._extract_column_name_and_alias(column) + column_alias = self._escape_and_quote_columns(column_alias, True) # Do we have a group by clause ? match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE) # Do we have aggregate columns ? @@ -165,26 +152,53 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str: if column_alias and column_alias != column_name: extend_columns.add(self._build_column_projection(column_name, column_alias, True)) if column_alias: - projection_columns.append(self._escape_and_quote_columns(column_alias)) + projection_columns.append(self._escape_and_quote_columns(column_alias,True)) else: projection_columns.append(self._escape_and_quote_columns(column_name)) # group by columns by_columns = self._group_by(group_by_cols) if has_aggregates or bool(by_columns): # Summarize can happen with or without aggregate being created - # escape each column with _escape_and_quote_columns summarize_statement = f"| summarize {', '.join(summarize_columns)} " if by_columns: summarize_statement = f"{summarize_statement} by {', '.join(by_columns)}" if extend_columns: - # escape each column with _escape_and_quote_columns extend_statement = f"| extend {', '.join(sorted(extend_columns))}" project_statement = f"| project {', '.join(projection_columns)}" if projection_columns else "" - unwrapped_order_by = _get_order_by(order_by_cols) + unwrapped_order_by = self._get_order_by(order_by_cols) + + sort_statement = f"| order by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else "" + # projection_statement = f"{extend_statement}{summarize_statement}{project_statement}{sort_statement}" + return { + "extend": extend_statement, + "summarize": summarize_statement, + "project": project_statement, + "sort": sort_statement, + } + + def _get_order_by(self, order_by_cols): + unwrapped_order_by = [] + for elem in order_by_cols: + if isinstance(elem, sql.elements._label_reference): + nested_element = elem.element + unwrapped_order_by.append( + f"{self._escape_and_quote_columns(nested_element._order_by_label_element.name,is_alias=True)} " + f"{'desc' if (nested_element.modifier == operators.desc_op) else 'asc'}" + ) + elif isinstance(elem, sql.elements.TextClause): + sort_parts = elem.text.split() + if len(sort_parts) == 2: + unwrapped_order_by.append( + f"{self._escape_and_quote_columns(sort_parts[0],is_alias=True)} {sort_parts[1].lower()}" + ) + elif len(sort_parts) == 1: + unwrapped_order_by.append(self._escape_and_quote_columns(sort_parts[0],is_alias=True)) + else: + unwrapped_order_by.append(elem.text.replace(" ASC", " asc").replace(" DESC", " desc")) + else: + logger.warning("Unsupported order by clause: %s of type %s", elem, type(elem)) + return unwrapped_order_by - sort_statement = f"| sort by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else "" - projection_statement = f"{extend_statement}{summarize_statement}{project_statement}{sort_statement}" - return projection_statement def _group_by(self, group_by_cols): by_columns = set() @@ -197,9 +211,11 @@ def _group_by(self, group_by_cols): return by_columns @staticmethod - def _escape_and_quote_columns(name: str): + def _escape_and_quote_columns(name: str, is_alias=False): + if name is None: + return None name = name.strip() - if KustoKqlCompiler._is_kql_function(name): + if KustoKqlCompiler._is_kql_function(name) and not is_alias: return name if name.startswith('"') and name.endswith('"'): name = name[1:-1] @@ -222,19 +238,23 @@ def _escape_and_quote_columns(name: str): return f'["{name}"]' @staticmethod - def sql_to_kql_where(where_clause: str): - where_clause = where_clause.strip() - #Handle 'IS NULL' and 'IS NOT NULL' -> KQL equivalents + def _sql_to_kql_where(where_clause: str): + where_clause = where_clause.strip().replace("\n", "") + # Handle 'IS NULL' and 'IS NOT NULL' -> KQL equivalents where_clause = re.sub( r'(\["[^\]]+"\])\s*IS NULL', r"isnull(\1)", where_clause, re.IGNORECASE ) # IS NULL -> isnull(["FieldName"]) where_clause = re.sub( r'(\["[^\]]+"\])\s*IS NOT NULL', r"isnotnull(\1)", where_clause, re.IGNORECASE ) # IS NOT NULL -> isnotnull(["FieldName"]) - #Handle comparison operators - where_clause = re.sub( - r"(\s)(=)\s*", r" \2= ", where_clause, re.IGNORECASE - ) # Change '=' to '==' for equality comparisons + # Handle comparison operators + # Change '=' to '==' for equality comparisons + where_clause = re.sub(r"(?<=[^=])=(?=\s|$|[^=])", r"==", where_clause, re.IGNORECASE) + # Remove spaces in < = and > = operators + where_clause = re.sub(r"\s*<\s*=\s*", "<=", where_clause, re.IGNORECASE) + where_clause = re.sub(r"\s*>\s*=\s*", ">=", where_clause, re.IGNORECASE) + where_clause = where_clause.replace(">==", ">=") + where_clause = where_clause.replace("<==", "<=") where_clause = re.sub(r"(\s)(<>|!=)\s*", r" \2 ", where_clause, re.IGNORECASE) # Handle '!=' and '<>' where_clause = re.sub( r"(\s)(<|<=|>|>=)\s*", r" \2 ", where_clause, re.IGNORECASE @@ -249,7 +269,7 @@ def sql_to_kql_where(where_clause: str): where_clause = re.sub( r"(\s)IN\s*\(([^)]+)\)", r"\1in (\2)", where_clause, re.IGNORECASE ) # IN operator (list of values) - #Handle BETWEEN operator (if needed) + # Handle BETWEEN operator (if needed) where_clause = re.sub( r"(\w+|\[\"[A-Za-z0-9_]+\"\]) (BETWEEN|between) (\d) (AND|and) (\d)", diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 182f4a0..0a09f95 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -11,8 +11,9 @@ distinct, literal_column, select, - text, + text, func, ) +from sqlalchemy.sql import sqltypes from sqlalchemy.sql.selectable import TextAsFrom from sqlalchemy_kusto.dialect_kql import KustoKqlCompiler @@ -38,7 +39,7 @@ def test_compiler_with_projection(): query_expected = ( 'let virtual_table = (["logs"] ' "| take 10);virtual_table" - '| extend id = ["Id"], tId = ["TypeId"]' + '| extend ["id"] = ["Id"], ["tId"] = ["TypeId"]' '| project ["id"], ["tId"], ["Type"]' "| take __[POSTCOMPILE_param_1]" ) @@ -108,11 +109,11 @@ def test_group_by_text(): query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") # raw query text from query query_expected = ( - '["ActiveUsersLastMonth"]| extend ActiveUserMetric = ["ActiveUsers"], ' - 'EventInfo_Time = ["EventInfo_Time"] / time(1d)' + '["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' + '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| summarize by ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["ActiveUserMetric"]' - "| sort by ActiveUserMetric desc" + '| order by ["ActiveUserMetric"] desc' ) assert query_compiled == query_expected @@ -127,10 +128,10 @@ def test_group_by_text_vaccine_dataset(): ) query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") query_expected = ( - 'database("superset").["CovidVaccineData"]' - '| summarize by ["country_name"]' - '| project ["country_name"]' - "| sort by country_name asc" + 'database("superset").["CovidVaccineData"]| ' + 'extend ["country_name"] = ["country_name"]| ' + 'summarize by ["country_name"]| ' + 'project ["country_name"]| order by ["country_name"] asc' ) assert query_compiled == query_expected @@ -163,10 +164,10 @@ def test_distinct_count_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend EventInfo_Time = ["EventInfo_Time"] / time(1d)' - '| summarize DistinctUsers = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' - "| sort by ActiveUserMetric desc" + '| order by ["ActiveUserMetric"] desc' ) assert query_compiled == query_expected @@ -187,10 +188,10 @@ def test_distinct_count_alt_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend EventInfo_Time = ["EventInfo_Time"] / time(1d)' - '| summarize DistinctUsers = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' - "| sort by ActiveUserMetric desc" + '| order by ["ActiveUserMetric"] desc' ) assert query_compiled == query_expected @@ -240,14 +241,10 @@ def test_limit(): sql = "logs" limit = 5 query = select("*").select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")).limit(limit) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") - query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5" - assert query_compiled == query_expected - def test_select_count(): kql_query = "logs" column_count = literal_column("count(*)").label("count") @@ -266,14 +263,52 @@ def test_select_count(): 'let inner_qry = (["logs"]);' "inner_qry" "| where Field1 > 1 and Field2 < 2" - "| summarize count = count() " + '| summarize ["count"] = count() ' '| project ["count"]' - "| sort by count desc" + '| order by ["count"] desc' "| take 5" ) assert query_compiled == query_expected +def compiler_with_startofmonth_group_by(): + metadata = MetaData() + sales_data = Table( + "SalesData", + metadata, + Column("order_date", String), + Column("product_line", String), + Column("sales", sqltypes.Float), + ) + query = ( + select( + [ + func.startofmonth(sales_data.c.order_date).label("order_date"), + sales_data.c.product_line.label("product_line"), + func.sum(sales_data.c.sales).label("(Sales)"), + ] + ) + .where( + text( + "order_date >= datetime('2003-01-01T00:00:00.000000') and order_date < datetime('2005-06-01T00:00:00.000000')" + ) + ) + .group_by( + func.startofmonth(sales_data.c.order_date), + sales_data.c.product_line, + ) + .order_by(text('"(Sales)" DESC')) + ) + + query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") + query_expected = ( + '["SalesData"]' + '| extend ["order_date"] = startofmonth(["order_date"])' + '| summarize ["(Sales)"] = sum(["sales"]) by ["order_date"], ["product_line"]' + '| project ["order_date"], ["product_line"], ["(Sales)"]' + '| order by ["(Sales)"] desc' + ) + assert query_compiled == query_expected def test_select_with_let(): kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y"