diff --git a/.github/workflows/qa.yml b/.github/workflows/qa.yml index 5ae432b..2bb9aa2 100644 --- a/.github/workflows/qa.yml +++ b/.github/workflows/qa.yml @@ -18,7 +18,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.10' - uses: actions/cache@v2 with: @@ -44,7 +44,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.10' - uses: actions/cache@v2 with: diff --git a/Makefile b/Makefile index 415fa69..a1bcb55 100644 --- a/Makefile +++ b/Makefile @@ -27,17 +27,24 @@ install-dev: # Install dev dependencies ############################################################################## # Development process ############################################################################## -check: # Run formatters and linters - @echo "Running checkers..." +format: + @echo "Running formatters..." - @echo "\n1. Run $(GREEN_ITALIC)isort$(DEFAULT) to order imports." - $(PYTHON) -m isort --profile black . + @echo "\n1. Run $(GREEN_ITALIC)ruff$(DEFAULT) to format code." + $(PYTHON) -m ruff check --fix-only . @echo "\n2. Run $(GREEN_ITALIC)black$(DEFAULT) to format code." $(PYTHON) -m black . - @echo "\n3. Run $(GREEN_ITALIC)pylint$(DEFAULT) to lint the project." - $(PYTHON) -m pylint setup.py sqlalchemy_kusto/ + +check: # Run formatters and linters + @echo "Running checkers..." + + @echo "\n1. Run $(GREEN_ITALIC)ruff$(DEFAULT) to check code." + $(PYTHON) -m ruff check . + + @echo "\n2. Run $(GREEN_ITALIC)black$(DEFAULT) to check code formatting." + $(PYTHON) -m black . --check @echo "\n4. Run $(GREEN_ITALIC)mypy$(DEFAULT) for type checking." $(PYTHON) -m mypy . diff --git a/pyproject.toml b/pyproject.toml index f04d999..58eb0c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,84 @@ requires = [ ] build-backend = "setuptools.build_meta" -[tool.black] +[tool.ruff] +target-version = "py310" + +# https://beta.ruff.rs/docs/settings/#line-length line-length = 120 -target-version = ["py38", "py39", "py310", "py311"] + +# https://beta.ruff.rs/docs/settings/#select +lint.select = [ + "F", # Pyflakes (https://beta.ruff.rs/docs/rules/#pyflakes-f) + "E", # pycodestyle (https://beta.ruff.rs/docs/rules/#pycodestyle-e-w) + "C90", # mccabe (https://beta.ruff.rs/docs/rules/#mccabe-c90) + "N", # pep8-naming (https://beta.ruff.rs/docs/rules/#pep8-naming-n) + "D", # pydocstyle (https://beta.ruff.rs/docs/rules/#pydocstyle-d) + "UP", # pyupgrade (https://beta.ruff.rs/docs/rules/#pyupgrade-up) + "ANN", # flake8-annotations (https://beta.ruff.rs/docs/rules/#flake8-annotations-ann) + "B", # flake8-bugbear (https://beta.ruff.rs/docs/rules/#flake8-bugbear-b) + "C4", # flake8-comprehensions (https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4) + "G", # flake8-logging-format (https://beta.ruff.rs/docs/rules/#flake8-logging-format-g) + "T20", # flake8-print (https://beta.ruff.rs/docs/rules/#flake8-print-t20) + "PT", # flake8-pytest-style (https://beta.ruff.rs/docs/rules/#flake8-pytest-style-pt) + "TID", # flake8-tidy-imports (https://beta.ruff.rs/docs/rules/#flake8-tidy-imports-tid) + "ARG", # flake8-unused-arguments (https://beta.ruff.rs/docs/rules/#flake8-unused-arguments-arg) + "PTH", # flake8-use-pathlib (https://beta.ruff.rs/docs/rules/#flake8-use-pathlib-pth) + "ERA", # eradicate (https://beta.ruff.rs/docs/rules/#eradicate-era) + "PL", # pylint (https://beta.ruff.rs/docs/rules/#pylint-pl) + "TRY", # tryceratops (https://beta.ruff.rs/docs/rules/#tryceratops-try) + "RUF100", # Unused noqa directive +] + +# https://beta.ruff.rs/docs/settings/#ignore +lint.ignore = [ + "C901", # too complex + + # pycodestyle (https://beta.ruff.rs/docs/rules/#pydocstyle-d) + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D106", # Missing docstring in public nested class + "D107", # Missing docstring in `__init__` + "D203", # 1 blank line required before class docstring + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + + "N818", # Exception name {name} should be named with an Error suffix; + + "TRY003", # Avoid specifying long messages outside the exception class + + # flake8-annotations + "ANN001", # Missing type annotation for function argument + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "ANN201", # Missing return type annotation for public function + "ANN202", # Missing return type annotation for private function + "ANN204", # Missing return type annotation for special method + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + + "ARG002", # Unused method argument + + "PLR0913", # Too many arguments in function definition +] + +[tool.ruff.lint.pycodestyle] +max-doc-length = 120 + +[tool.ruff.lint.pydocstyle] +# Use Google-style docstrings +convention = "google" + +[tool.ruff.lint.flake8-pytest-style] +# Set the parametrize values type in tests. +parametrize-values-type = "list" + +[tool.black] +line-length = 88 +target-version = ["py310", "py311"] include = ".pyi?$" exclude = """ ( @@ -24,36 +99,13 @@ exclude = """ ) """ -[tool.isort] -line_length = 120 -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true - [tool.mypy] -python_version = "3.8" +python_version = "3.10" strict_optional = true show_error_codes = true warn_redundant_casts = true warn_unused_ignores = true -disallow_any_generics = true +disallow_any_generics = false check_untyped_defs = true no_implicit_reexport = true ignore_missing_imports = true - -[tool.pylint.messages_control] -max-line-length = 120 -disable = [ - "consider-using-f-string", - "missing-class-docstring", - "missing-function-docstring", - "missing-module-docstring", - "no-self-use", - "protected-access", - "too-few-public-methods", - "too-many-arguments", - "too-many-locals", - "too-many-public-methods", - "unused-argument", -] diff --git a/setup.py b/setup.py index b39a299..1bcf0bb 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +from pathlib import Path from setuptools import find_packages, setup NAME = "sqlalchemy-kusto" @@ -12,16 +13,16 @@ ] EXTRAS = { "dev": [ - "black>=21.12b0", - "isort>=5.10.1", - "mypy==0.971", - "pylint==2.15.0", - "pytest>=6.2.5", - "python-dotenv>=0.19.2", + "black>=24.10.0", + "mypy>=1.14.1", + "pytest>=8.3.4", + "python-dotenv>=1.0.1", + "ruff>=0.8.6", ] } -with open("README.md", "r", encoding="utf-8") as f: +path = Path("README.md") +with path.open(encoding="utf-8") as f: LONG_DESCRIPTION = f.read() setup( @@ -51,7 +52,7 @@ project_urls={ "Bug Tracker": "https://github.com/dodopizza/sqlalchemy-kusto/issues", }, - python_requires=">=3.8", + python_requires=">=3.10", version=VERSION, zip_safe=False, ) diff --git a/sqlalchemy_kusto/__init__.py b/sqlalchemy_kusto/__init__.py index 2d1062c..2f5fcb1 100644 --- a/sqlalchemy_kusto/__init__.py +++ b/sqlalchemy_kusto/__init__.py @@ -1,6 +1,5 @@ from sqlalchemy_kusto.dbapi import connect -# pylint: disable=redefined-builtin from sqlalchemy_kusto.errors import ( DatabaseError, DataError, @@ -31,7 +30,7 @@ "Warning", ] -apilevel = "2.0" # pylint: disable=invalid-name +apilevel = "2.0" # Threads may share the module and connections -threadsafety = 2 # pylint: disable=invalid-name -paramstyle = "pyformat" # pylint: disable=invalid-name +threadsafety = 2 +paramstyle = "pyformat" diff --git a/sqlalchemy_kusto/dbapi.py b/sqlalchemy_kusto/dbapi.py index 71ba521..1b697a4 100644 --- a/sqlalchemy_kusto/dbapi.py +++ b/sqlalchemy_kusto/dbapi.py @@ -1,7 +1,12 @@ from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any -from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder +from azure.identity import WorkloadIdentityCredential +from azure.kusto.data import ( + ClientRequestProperties, + KustoClient, + KustoConnectionStringBuilder, +) from azure.kusto.data._models import KustoResultColumn from azure.kusto.data.exceptions import KustoAuthenticationError, KustoServiceError @@ -13,7 +18,7 @@ def check_closed(func): def decorator(self, *args, **kwargs): if self.closed: - raise Exception("{klass} already closed".format(klass=self.__class__.__name__)) + raise ValueError(f"{self.__class__.__name__} already closed") return func(self, *args, **kwargs) return decorator @@ -23,8 +28,8 @@ def check_result(func): """Decorator that checks if the cursor has results from `execute`.""" def decorator(self, *args, **kwargs): - if self._results is None: # pylint: disable=protected-access - raise Exception("Called before `execute`") + if self._results is None: + raise ValueError("Called before `execute`") return func(self, *args, **kwargs) return decorator @@ -34,13 +39,23 @@ def connect( cluster: str, database: str, msi: bool = False, - user_msi: str = None, - azure_ad_client_id: str = None, - azure_ad_client_secret: str = None, - azure_ad_tenant_id: str = None, + workload_identity: bool = False, + user_msi: str | None = None, + azure_ad_client_id: str | None = None, + azure_ad_client_secret: str | None = None, + azure_ad_tenant_id: str | None = None, ): """Return a connection to the database.""" - return Connection(cluster, database, msi, user_msi, azure_ad_client_id, azure_ad_client_secret, azure_ad_tenant_id) + return Connection( + cluster, + database, + msi, + workload_identity, + user_msi, + azure_ad_client_id, + azure_ad_client_secret, + azure_ad_tenant_id, + ) class Connection: @@ -51,13 +66,14 @@ def __init__( cluster: str, database: str, msi: bool = False, - user_msi: str = None, - azure_ad_client_id: str = None, - azure_ad_client_secret: str = None, - azure_ad_tenant_id: str = None, + workload_identity: bool = False, + user_msi: str | None = None, + azure_ad_client_id: str | None = None, + azure_ad_client_secret: str | None = None, + azure_ad_tenant_id: str | None = None, ): self.closed = False - self.cursors: List[Cursor] = [] + self.cursors: list[Cursor] = [] kcsb = None if azure_ad_client_id and azure_ad_client_secret and azure_ad_tenant_id: @@ -68,15 +84,27 @@ def __init__( app_key=azure_ad_client_secret, authority_id=azure_ad_tenant_id, ) + elif workload_identity: + # Workload Identity + kcsb = KustoConnectionStringBuilder.with_azure_token_credential( + cluster, WorkloadIdentityCredential() + ) elif msi: # Managed Service Identity (MSI) - kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication( - cluster, client_id=user_msi - ) + if user_msi is None or user_msi == "": + # System managed identity + kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication( + cluster + ) + else: + # user managed identity + kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication( + cluster, client_id=user_msi + ) else: # neither SP or MSI kcsb = KustoConnectionStringBuilder.with_az_cli_authentication(cluster) - kcsb._set_connector_details("sqlalchemy-kusto", "0.1.0") # pylint: disable=protected-access + kcsb._set_connector_details("sqlalchemy-kusto", "1.1.0") self.kusto_client = KustoClient(kcsb) self.database = database self.properties = ClientRequestProperties() @@ -84,7 +112,6 @@ def __init__( @check_closed def close(self): """Close the connection now. Kusto does not require to close the connection.""" - # self.closed = True for cursor in self.cursors: cursor.close() @@ -130,17 +157,19 @@ def __init__( self, kusto_client: KustoClient, database: str, - properties: Optional[ClientRequestProperties] = None, + properties: ClientRequestProperties | None = None, ): - self._results: Optional[List[Tuple[Any, ...]]] = None + self._results: list[tuple[Any, ...]] | None = None self.kusto_client = kusto_client self.database = database self.closed = False - self.description: Optional[List[CursorDescriptionRow]] = None + self.description: list[CursorDescriptionRow] | None = None self.current_item_index = 0 - self.properties = properties if properties is not None else ClientRequestProperties() + self.properties = ( + properties if properties is not None else ClientRequestProperties() + ) - @property # type: ignore + @property @check_result @check_closed def rowcount(self) -> int: @@ -152,7 +181,6 @@ def rowcount(self) -> int: @check_closed def close(self): """Closes the cursor.""" - # self.closed = True @check_closed def execute(self, operation, parameters=None) -> "Cursor": @@ -165,23 +193,29 @@ def execute(self, operation, parameters=None) -> "Cursor": query = Cursor._apply_parameters(operation, parameters) query = query.rstrip() try: - server_response = self.kusto_client.execute(self.database, query, self.properties) + server_response = self.kusto_client.execute( + self.database, query, self.properties + ) except KustoServiceError as kusto_error: - raise errors.DatabaseError(str(kusto_error)) + raise errors.DatabaseError(str(kusto_error)) from kusto_error except KustoAuthenticationError as context_error: - raise errors.OperationalError(str(context_error)) + raise errors.OperationalError(str(context_error)) from context_error rows = [] for row in server_response.primary_results[0]: rows.append(tuple(row.to_list())) self._results = rows - self.description = self._get_description_from_columns(server_response.primary_results[0].columns) + self.description = self._get_description_from_columns( + server_response.primary_results[0].columns + ) return self @check_closed def executemany(self, operation, seq_of_parameters=None): - """Not supported""" - raise NotImplementedError("`executemany` is not supported, use `execute` instead") + """Not supported.""" + raise NotImplementedError( + "`executemany` is not supported, use `execute` instead" + ) @check_result @check_closed @@ -199,7 +233,7 @@ def fetchone(self): @check_result @check_closed - def fetchmany(self, size: int = None): + def fetchmany(self, size: int | None = None): """ Fetches the next set of rows of a query result, returning a sequence of sequences (e.g. a list of tuples). An empty sequence is returned when @@ -224,15 +258,17 @@ def fetchall(self): @check_closed def setinputsizes(self, sizes): - """Not supported""" + """Not supported.""" @check_closed def setoutputsizes(self, sizes): - """Not supported""" + """Not supported.""" @staticmethod - def _get_description_from_columns(columns: List[KustoResultColumn]) -> List[CursorDescriptionRow]: - """Gets CursorDescriptionRow for Kusto columns""" + def _get_description_from_columns( + columns: list[KustoResultColumn], + ) -> list[CursorDescriptionRow]: + """Gets CursorDescriptionRow for Kusto columns.""" return [ CursorDescriptionRow( name=column.column_name, @@ -258,31 +294,32 @@ def __next__(self): next = __next__ @staticmethod - def _apply_parameters(operation, parameters) -> str: - """Applies parameters to operation string""" + def _apply_parameters(operation, parameters: dict) -> str: + """Applies parameters to operation string.""" if not parameters: return operation - escaped_parameters = {key: Cursor._escape(value) for key, value in parameters.items()} + escaped_parameters = { + key: Cursor._escape(value) for key, value in parameters.items() + } return operation % escaped_parameters @staticmethod - def _escape(value) -> str: + def _escape(value: Any) -> str: """ Escape the parameter value. Note that bool is a subclass of int so order of statements matter. """ - if value == "*": return value if isinstance(value, str): return "'{}'".format(value.replace("'", "''")) if isinstance(value, bool): return "TRUE" if value else "FALSE" - if isinstance(value, (int, float)): + if isinstance(value, int | float): return str(value) - if isinstance(value, (list, tuple)): + if isinstance(value, list | tuple): return ", ".join(Cursor._escape(element) for element in value) return value diff --git a/sqlalchemy_kusto/dialect_base.py b/sqlalchemy_kusto/dialect_base.py index ce874ba..0438f95 100644 --- a/sqlalchemy_kusto/dialect_base.py +++ b/sqlalchemy_kusto/dialect_base.py @@ -1,12 +1,20 @@ import json from abc import ABC from types import ModuleType -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from sqlalchemy.engine import Connection, default from sqlalchemy.engine.url import URL from sqlalchemy.sql import compiler -from sqlalchemy.types import DATE, TIMESTAMP, BigInteger, Boolean, Float, Integer, String +from sqlalchemy.types import ( + DATE, + TIMESTAMP, + BigInteger, + Boolean, + Float, + Integer, + String, +) import sqlalchemy_kusto @@ -56,7 +64,7 @@ class KustoBaseDialect(default.DefaultDialect, ABC): description_encoding = None supports_native_boolean = True supports_simple_order_by_label = True - _map_parse_connection_parameters: Dict[str, Any] = { + _map_parse_connection_parameters: dict[str, Any] = { "msi": parse_bool_argument, "azure_ad_client_id": str, "azure_ad_client_secret": str, @@ -66,11 +74,11 @@ class KustoBaseDialect(default.DefaultDialect, ABC): } @classmethod - def dbapi(cls) -> ModuleType: # pylint: disable-msg=method-hidden + def dbapi(cls) -> ModuleType: return sqlalchemy_kusto - def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: - kwargs: Dict[str, Any] = { + def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]: + kwargs: dict[str, Any] = { "cluster": "https://" + url.host, "database": url.database, } @@ -84,21 +92,33 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: return [], kwargs - def get_schema_names(self, connection: Connection, **kwargs) -> List[str]: + def get_schema_names(self, connection: Connection, **kwargs) -> list[str]: result = connection.execute(".show databases | project DatabaseName") return [row.DatabaseName for row in result] - def has_table(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs) -> bool: + def has_table( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs, + ) -> bool: return table_name in self.get_table_names(connection, schema) - def get_table_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> List[str]: + def get_table_names( + self, connection: Connection, schema: str | None = None, **kwargs + ) -> list[str]: # Schema is not used in Kusto cause database is written in the connection string result = connection.execute(".show tables | project TableName") return [row.TableName for row in result] def get_columns( - self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs - ) -> List[Dict[str, Any]]: + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs, + ) -> list[dict[str, Any]]: table_search_query = f""" .show tables | where TableName == "{table_name}" @@ -117,16 +137,23 @@ def get_columns( query_result = connection.execute(function_schema) rows = list(query_result) entity_schema = json.loads(rows[0].Schema) - return [self.schema_definition(column) for column in entity_schema["OutputColumns"]] - entity_type = "table" if table_search_result.rowcount == 1 else "materialized-view" + return [ + self.schema_definition(column) + for column in entity_schema["OutputColumns"] + ] + entity_type = ( + "table" if table_search_result.rowcount == 1 else "materialized-view" + ) query = f".show {entity_type} {table_name} schema as json" query_result = connection.execute(query) rows = list(query_result) entity_schema = json.loads(rows[0].Schema) - return [self.schema_definition(column) for column in entity_schema["OrderedColumns"]] + return [ + self.schema_definition(column) for column in entity_schema["OrderedColumns"] + ] @staticmethod - def schema_definition(column): + def schema_definition(column) -> dict: return { "name": column["Name"], "type": kql_to_sql_types[column["CslType"].lower()], @@ -134,39 +161,65 @@ def schema_definition(column): "default": "", } - def get_view_names(self, connection: Connection, schema: Optional[str] = None, **kwargs) -> List[str]: - materialized_views = connection.execute(".show materialized-views | project Name") + def get_view_names( + self, connection: Connection, schema: str | None = None, **kwargs + ) -> list[str]: + materialized_views = connection.execute( + ".show materialized-views | project Name" + ) # Functions are also Views. # Filtering no input functions specifically here as there is no way to pass parameters today - functions = connection.execute(".show functions | where Parameters =='()' | project Name") + functions = connection.execute( + ".show functions | where Parameters =='()' | project Name" + ) materialized_view = [row.Name for row in materialized_views] view = [row.Name for row in functions] return materialized_view + view - def get_pk_constraint(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kw): + def get_pk_constraint( + self, connection: Connection, table_name: str, schema: str | None = None, **kw + ): return {"constrained_columns": [], "name": None} def get_foreign_keys(self, connection, table_name, schema=None, **kwargs): return [] - def get_check_constraints(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs): + def get_check_constraints( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs, + ): return [] def get_table_comment( - self, connection: Connection, table_name, schema: Optional[str] = None, **kwargs - ) -> Dict[str, Any]: - """Not implemented""" + self, connection: Connection, table_name, schema: str | None = None, **kwargs + ) -> dict[str, Any]: + """Not implemented.""" return {"text": ""} def get_indexes( - self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs - ) -> List[Dict[str, Any]]: + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs, + ) -> list[dict[str, Any]]: return [] - def get_unique_constraints(self, connection: Connection, table_name: str, schema: Optional[str] = None, **kwargs): + def get_unique_constraints( + self, + connection: Connection, + table_name: str, + schema: str | None = None, + **kwargs, + ): return [] - def _check_unicode_returns(self, connection: Connection, additional_tests: List[Any] = None) -> bool: + def _check_unicode_returns( + self, connection: Connection, additional_tests: list[Any] | None = None + ) -> bool: return True def _check_unicode_description(self, connection: Connection) -> bool: @@ -176,9 +229,10 @@ def do_ping(self, dbapi_connection: sqlalchemy_kusto.dbapi.Connection): try: query = ".show tables" dbapi_connection.execute(query) - return True except sqlalchemy_kusto.OperationalError: return False + else: + return True def do_rollback(self, dbapi_connection: sqlalchemy_kusto.dbapi.Connection): pass @@ -225,7 +279,13 @@ def set_isolation_level(self, dbapi_conn, level): def get_isolation_level(self, dbapi_conn): pass - def get_view_definition(self, connection: Connection, view_name: str, schema: Optional[str] = None, **kwargs): + def get_view_definition( + self, + connection: Connection, + view_name: str, + schema: str | None = None, + **kwargs, + ): pass def get_primary_keys(self, connection, table_name, schema=None, **kw): diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index c60b43a..efbec5e 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -1,6 +1,5 @@ import logging import re -from typing import List, Optional, Tuple from sqlalchemy import Column, exc from sqlalchemy.sql import compiler, operators, selectable @@ -17,7 +16,7 @@ class UniversalSet: - def __contains__(self, item): + def __contains__(self, item) -> bool: return True @@ -25,7 +24,7 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer): # We want to quote all table and column names to prevent unconventional names usage reserved_words = UniversalSet() - def __init__(self, dialect, **kw): + def __init__(self, dialect, **kw) -> None: super().__init__(dialect, initial_quote='["', final_quote='"]', **kw) @@ -48,11 +47,13 @@ def visit_select( lateral=False, from_linter=None, **kwargs, - ): + ) -> str: logger.debug("Incoming query: %s", select_stmt) if len(select_stmt.get_final_froms()) != 1: - raise NotSupportedError('Only single "select from" query is supported in kql compiler') + raise NotSupportedError( + 'Only single "select from" query is supported in kql compiler' + ) compiled_query_lines = [] @@ -61,7 +62,9 @@ def visit_select( query = self._get_most_inner_element(from_object.element) (main, lets) = self._extract_let_statements(query.text) compiled_query_lines.extend(lets) - compiled_query_lines.append(f"let {from_object.name} = ({self._convert_schema_in_statement(main)});") + compiled_query_lines.append( + f"let {from_object.name} = ({self._convert_schema_in_statement(main)});" + ) compiled_query_lines.append(from_object.name) elif hasattr(from_object, "name"): if from_object.schema is not None: @@ -70,7 +73,9 @@ def visit_select( unquoted_name = from_object.name.strip("\"'") compiled_query_lines.append(f'["{unquoted_name}"]') else: - compiled_query_lines.append(self._convert_schema_in_statement(from_object.text)) + compiled_query_lines.append( + self._convert_schema_in_statement(from_object.text) + ) if select_stmt._whereclause is not None: where_clause = select_stmt._whereclause._compiler_dispatch(self, **kwargs) @@ -81,11 +86,11 @@ def visit_select( if projections: compiled_query_lines.append(projections) - if select_stmt._limit_clause is not None: # pylint: disable=protected-access + if select_stmt._limit_clause is not None: kwargs["literal_execute"] = True compiled_query_lines.append( f"| take {self.process(select_stmt._limit_clause, **kwargs)}" - ) # pylint: disable=protected-access + ) compiled_query_lines = list(filter(None, compiled_query_lines)) @@ -97,7 +102,7 @@ def limit_clause(self, select, **kw): return "" def _get_projection_or_summarize(self, select: selectable.Select) -> str: - """Builds the ending part of the query either project or summarize""" + """Builds the ending part of the query either project or summarize.""" columns = select.inner_columns if columns is not None: column_labels = [] @@ -108,10 +113,14 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str: if column_name in aggregates_sql_to_kql: is_summarize = True column_labels.append( - self._build_column_projection(aggregates_sql_to_kql[column_name], column_alias) + self._build_column_projection( + aggregates_sql_to_kql[column_name], column_alias + ) ) else: - column_labels.append(self._build_column_projection(column_name, column_alias)) + column_labels.append( + self._build_column_projection(column_name, column_alias) + ) if column_labels: projection_type = "summarize" if is_summarize else "project" @@ -119,7 +128,7 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> str: return "" def _get_most_inner_element(self, clause): - """Finds the most nested element in clause""" + """Finds the most nested element in clause.""" inner_element = getattr(clause, "element", None) if inner_element is not None: return self._get_most_inner_element(inner_element) @@ -127,8 +136,8 @@ def _get_most_inner_element(self, clause): return clause @staticmethod - def _extract_let_statements(clause) -> Tuple[str, List[str]]: - """Separates the final query from let statements""" + def _extract_let_statements(clause) -> tuple[str, list[str]]: + """Separates the final query from let statements.""" rows = [s.strip() for s in clause.split(";")] main = next(filter(lambda row: not row.startswith("let"), rows), None) @@ -139,15 +148,17 @@ def _extract_let_statements(clause) -> Tuple[str, List[str]]: return main, lets @staticmethod - def _extract_column_name_and_alias(column: Column) -> Tuple[str, Optional[str]]: + def _extract_column_name_and_alias(column: Column) -> tuple[str, str | None]: if hasattr(column, "element"): return column.element.name, column.name return column.name, None @staticmethod - def _build_column_projection(column_name: str, column_alias: str = None): - """Generates column alias semantic for project statement""" + def _build_column_projection( + column_name: str, column_alias: str | None = None + ) -> str: + """Generates column alias semantic for project statement.""" return f"{column_alias} = {column_name}" if column_alias else column_name @staticmethod @@ -166,7 +177,6 @@ def _convert_schema_in_statement(query: str) -> str: - ["schema"].["table"] -> database("schema").["table"] - ["table"] -> ["table"] """ - pattern = r"^\[?([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")?\]?\.?\[?([a-zA-Z0-9]+\b|\"[a-zA-Z0-9 \-_.]+\")\]?" match = re.search(pattern, query) if not match: @@ -179,7 +189,9 @@ def _convert_schema_in_statement(query: str) -> str: return query.replace(original, f'["{unquoted_table}"]', 1) unquoted_schema = match.group(1).strip("\"'") - return query.replace(original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1) + return query.replace( + original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1 + ) class KustoKqlHttpsDialect(KustoBaseDialect): diff --git a/sqlalchemy_kusto/dialect_sql.py b/sqlalchemy_kusto/dialect_sql.py index 6404080..6ce2de8 100644 --- a/sqlalchemy_kusto/dialect_sql.py +++ b/sqlalchemy_kusto/dialect_sql.py @@ -5,17 +5,17 @@ class KustoSqlCompiler(compiler.SQLCompiler): def get_select_precolumns(self, select, **kw) -> str: - """Kusto uses TOP instead of LIMIT""" + """Kusto uses TOP instead of LIMIT.""" select_precolumns = super().get_select_precolumns(select, **kw) if select._limit_clause is not None: kw["literal_execute"] = True - select_precolumns += "TOP %s " % self.process(select._limit_clause, **kw) + select_precolumns += f"TOP {self.process(select._limit_clause, **kw)} " return select_precolumns def limit_clause(self, select, **kw): - """Do not add LIMIT to the end of the query""" + """Do not add LIMIT to the end of the query.""" return "" def visit_sequence(self, sequence, **kw): @@ -24,16 +24,21 @@ def visit_sequence(self, sequence, **kw): def visit_empty_set_expr(self, element_types): pass - def update_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): pass - def delete_extra_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): pass class KustoSqlHttpsDialect(KustoBaseDialect): name = "kustosql" statement_compiler = KustoSqlCompiler - # For some reason supports_statement_cache doesn't work when defined in the KustoBaseDialect. + # For some reason supports_statement_cache + # doesn't work when defined in the KustoBaseDialect. # Need to investigate why it happens. supports_statement_cache = True diff --git a/sqlalchemy_kusto/errors.py b/sqlalchemy_kusto/errors.py index ef40a75..70116d1 100644 --- a/sqlalchemy_kusto/errors.py +++ b/sqlalchemy_kusto/errors.py @@ -2,7 +2,7 @@ class Error(Exception): pass -class Warning(Exception): # pylint: disable-msg=redefined-builtin +class Warning(Exception): pass diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 50e0f3c..90a2bcb 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,8 +3,12 @@ from dotenv import load_dotenv from sqlalchemy.dialects import registry -registry.register("kustosql.https", "sqlalchemy_kusto.dialect_sql", "KustoSqlHttpsDialect") -registry.register("kustokql.https", "sqlalchemy_kusto.dialect_kql", "KustoKqlHttpsDialect") +registry.register( + "kustosql.https", "sqlalchemy_kusto.dialect_sql", "KustoSqlHttpsDialect" +) +registry.register( + "kustokql.https", "sqlalchemy_kusto.dialect_kql", "KustoKqlHttpsDialect" +) load_dotenv() AZURE_AD_CLIENT_ID = os.environ.get("AZURE_AD_CLIENT_ID", "") diff --git a/tests/integration/test_dbapi.py b/tests/integration/test_dbapi.py index f68fe4a..94c7d39 100644 --- a/tests/integration/test_dbapi.py +++ b/tests/integration/test_dbapi.py @@ -8,20 +8,21 @@ ) -def test_connect(): +def test_connect() -> None: connection = connect("test", DATABASE, True) assert connection is not None -def test_execute(): +def test_execute() -> None: connection = connect( KUSTO_URL, DATABASE, False, + False, None, azure_ad_client_id=AZURE_AD_CLIENT_ID, azure_ad_client_secret=AZURE_AD_CLIENT_SECRET, azure_ad_tenant_id=AZURE_AD_TENANT_ID, ) - result = connection.execute(f"select 1").fetchall() + result = connection.execute("select 1").fetchall() assert result is not None diff --git a/tests/integration/test_dialect_sql.py b/tests/integration/test_dialect_sql.py index 07168d2..4f26563 100644 --- a/tests/integration/test_dialect_sql.py +++ b/tests/integration/test_dialect_sql.py @@ -1,7 +1,13 @@ +from collections.abc import Generator +from typing import Any import uuid import pytest -from azure.kusto.data import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder +from azure.kusto.data import ( + ClientRequestProperties, + KustoClient, + KustoConnectionStringBuilder, +) from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine from tests.integration.conftest import ( @@ -21,31 +27,31 @@ ) -def test_ping(): +def test_ping() -> None: conn = engine.connect() result = engine.dialect.do_ping(conn) assert result is True -def test_get_table_names(temp_table_name): +def test_get_table_names(temp_table_name: str) -> None: conn = engine.connect() result = engine.dialect.get_table_names(conn) assert temp_table_name in result -def test_get_view_names(temp_table_name): +def test_get_view_names(temp_table_name: str) -> None: conn = engine.connect() result = engine.dialect.get_view_names(conn) assert f"{temp_table_name}_fn" in result -def test_get_columns(temp_table_name): +def test_get_columns(temp_table_name: str) -> None: conn = engine.connect() columns_result = engine.dialect.get_columns(conn, temp_table_name) - assert set(["Id", "Text"]) == set([c["name"] for c in columns_result]) + assert {"Id", "Text"} == {c["name"] for c in columns_result} -def test_fetch_one(temp_table_name): +def test_fetch_one(temp_table_name: str) -> None: engine.connect() result = engine.execute(f"select top 2 * from {temp_table_name} order by Id") assert result.fetchone() == (1, "value_1") @@ -53,21 +59,33 @@ def test_fetch_one(temp_table_name): assert result.fetchone() is None -def test_fetch_many(temp_table_name): +def test_fetch_many(temp_table_name: str) -> None: engine.connect() result = engine.execute(f"select top 5 * from {temp_table_name} order by Id") - assert set([(x[0], x[1]) for x in result.fetchmany(3)]) == set([(1, "value_1"), (2, "value_2"), (3, "value_3")]) - assert set([(x[0], x[1]) for x in result.fetchmany(3)]) == set([(4, "value_4"), (5, "value_5")]) + assert {(x[0], x[1]) for x in result.fetchmany(3)} == { + (1, "value_1"), + (2, "value_2"), + (3, "value_3"), + } + assert {(x[0], x[1]) for x in result.fetchmany(3)} == { + (4, "value_4"), + (5, "value_5"), + } -def test_fetch_all(temp_table_name): +def test_fetch_all(temp_table_name: str) -> None: engine.connect() result = engine.execute(f"select top 3 * from {temp_table_name} order by Id") - assert set([(x[0], x[1]) for x in result.fetchall()]) == set([(1, "value_1"), (2, "value_2"), (3, "value_3")]) + assert {(x[0], x[1]) for x in result.fetchall()} == { + (1, "value_1"), + (2, "value_2"), + (3, "value_3"), + } -def test_limit(temp_table_name): +def test_limit(temp_table_name: str) -> None: + limit = 5 stream = Table( temp_table_name, MetaData(), @@ -75,61 +93,72 @@ def test_limit(temp_table_name): Column("Text", String), ) - query = stream.select().limit(5) + query = stream.select().limit(limit) engine.connect() result = engine.execute(query) result_length = len(result.fetchall()) - assert result_length == 5 + assert result_length == limit -def get_kcsb(): +def get_kcsb() -> Any: return ( KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL) - if not AZURE_AD_CLIENT_ID and not AZURE_AD_CLIENT_SECRET and not AZURE_AD_TENANT_ID + if not AZURE_AD_CLIENT_ID + and not AZURE_AD_CLIENT_SECRET + and not AZURE_AD_TENANT_ID else KustoConnectionStringBuilder.with_aad_application_key_authentication( KUSTO_URL, AZURE_AD_CLIENT_ID, AZURE_AD_CLIENT_SECRET, AZURE_AD_TENANT_ID ) ) -def _create_temp_table(table_name: str): +def _create_temp_table(table_name: str) -> None: client = KustoClient(get_kcsb()) - response = client.execute(DATABASE, f".create table {table_name}(Id: int, Text: string)", ClientRequestProperties()) + client.execute( + DATABASE, + f".create table {table_name}(Id: int, Text: string)", + ClientRequestProperties(), + ) -def _create_temp_fn(fn_name: str): +def _create_temp_fn(fn_name: str) -> None: client = KustoClient(get_kcsb()) - response = client.execute(DATABASE, f".create function {fn_name}() {{ print now()}}", ClientRequestProperties()) + client.execute( + DATABASE, + f".create function {fn_name}() {{ print now()}}", + ClientRequestProperties(), + ) -def _ingest_data_to_table(table_name: str): +def _ingest_data_to_table(table_name: str) -> None: client = KustoClient(get_kcsb()) data_to_ingest = {i: "value_" + str(i) for i in range(1, 10)} str_data = "\n".join("{},{}".format(*p) for p in data_to_ingest.items()) ingest_query = f""".ingest inline into table {table_name} <| {str_data}""" - response = client.execute(DATABASE, ingest_query, ClientRequestProperties()) + client.execute(DATABASE, ingest_query, ClientRequestProperties()) -def _drop_table(table_name: str): +def _drop_table(table_name: str) -> None: client = KustoClient(get_kcsb()) _ = client.execute(DATABASE, f".drop table {table_name}", ClientRequestProperties()) - _ = client.execute(DATABASE, f".drop function {table_name}_fn", ClientRequestProperties()) + _ = client.execute( + DATABASE, f".drop function {table_name}_fn", ClientRequestProperties() + ) -@pytest.fixture() -def temp_table_name(): +@pytest.fixture +def temp_table_name() -> str: return "_temp_" + uuid.uuid4().hex @pytest.fixture(autouse=True) -def run_around_tests(temp_table_name): +def run_around_tests(temp_table_name: str) -> Generator[str, None, None]: _create_temp_table(temp_table_name) _create_temp_fn(f"{temp_table_name}_fn") _ingest_data_to_table(temp_table_name) # A test function will be run at this point yield temp_table_name _drop_table(temp_table_name) - # assert files_before == files_after diff --git a/tests/integration/test_error_handling.py b/tests/integration/test_error_handling.py index 38079e5..0a618b2 100644 --- a/tests/integration/test_error_handling.py +++ b/tests/integration/test_error_handling.py @@ -5,7 +5,7 @@ from tests.integration.conftest import DATABASE, KUSTO_SQL_ALCHEMY_URL -def test_operational_error(): +def test_operational_error() -> None: wrong_tenant_id = "wrong_tenant_id" azure_ad_client_id = "x" azure_ad_client_secret = "x" diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index e621075..71824c5 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -1,12 +1,22 @@ import pytest import sqlalchemy as sa -from sqlalchemy import Column, MetaData, String, Table, column, create_engine, literal_column, select, text +from sqlalchemy import ( + Column, + MetaData, + String, + Table, + column, + create_engine, + literal_column, + select, + text, +) from sqlalchemy.sql.selectable import TextAsFrom engine = create_engine("kustokql+https://localhost/testdb") -def test_compiler_with_projection(): +def test_compiler_with_projection() -> None: statement_str = "logs | take 10" stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table") query = sa.select( @@ -31,7 +41,7 @@ def test_compiler_with_projection(): assert query_compiled == query_expected -def test_compiler_with_star(): +def test_compiler_with_star() -> None: statement_str = "logs | take 10" stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table") query = sa.select( @@ -42,20 +52,30 @@ def test_compiler_with_star(): query = query.limit(10) query_compiled = str(query.compile(engine)).replace("\n", "") - query_expected = 'let virtual_table = (["logs"] | take 10);' "virtual_table" "| take __[POSTCOMPILE_param_1]" + query_expected = ( + 'let virtual_table = (["logs"] | take 10);' + "virtual_table" + "| take __[POSTCOMPILE_param_1]" + ) assert query_compiled == query_expected -def test_select_from_text(): - query = select([column("Field1"), column("Field2")]).select_from(text("logs")).limit(100) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") +def test_select_from_text() -> None: + query = ( + select([column("Field1"), column("Field2")]) + .select_from(text("logs")) + .limit(100) + ) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") query_expected = '["logs"]' "| project Field1, Field2" "| take 100" assert query_compiled == query_expected -def test_use_table(): +def test_use_table() -> None: metadata = MetaData() stream = Table( "logs", @@ -67,23 +87,31 @@ def test_use_table(): query = stream.select().limit(5) query_compiled = str(query.compile(engine)).replace("\n", "") - query_expected = '["logs"]' "| project Field1, Field2" "| take __[POSTCOMPILE_param_1]" + query_expected = ( + '["logs"]' "| project Field1, Field2" "| take __[POSTCOMPILE_param_1]" + ) assert query_compiled == query_expected -def test_limit(): +def test_limit() -> None: sql = "logs" limit = 5 - query = select("*").select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")).limit(limit) + query = ( + select("*") + .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")) + .limit(limit) + ) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5" assert query_compiled == query_expected -def test_select_count(): +def test_select_count() -> None: kql_query = "logs" column_count = literal_column("count(*)").label("count") query = ( @@ -95,7 +123,9 @@ def test_select_count(): .limit(5) ) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") query_expected = ( 'let inner_qry = (["logs"]);' @@ -108,11 +138,17 @@ def test_select_count(): assert query_compiled == query_expected -def test_select_with_let(): +def test_select_with_let() -> None: kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y" - query = select("*").select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry")).limit(5) + query = ( + select("*") + .select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry")) + .limit(5) + ) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") query_expected = ( "let x = 5;" @@ -125,7 +161,7 @@ def test_select_with_let(): assert query_compiled == query_expected -def test_quotes(): +def test_quotes() -> None: quote = engine.dialect.identifier_preparer.quote metadata = MetaData() stream = Table( @@ -150,7 +186,7 @@ def test_quotes(): @pytest.mark.parametrize( - "schema_name,table_name,expected_table_name", + ("schema_name", "table_name", "expected_table_name"), [ ("schema", "table", 'database("schema").["table"]'), ("schema", '"table.name"', 'database("schema").["table.name"]'), @@ -161,7 +197,9 @@ def test_quotes(): (None, "MyTable", '["MyTable"]'), ], ) -def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_name: str): +def test_schema_from_metadata( + table_name: str, schema_name: str, expected_table_name: str +) -> None: metadata = MetaData(schema=schema_name) if schema_name else MetaData() stream = Table( table_name, @@ -176,7 +214,7 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_ @pytest.mark.parametrize( - "query_table_name,expected_table_name", + ("query_table_name", "expected_table_name"), [ ("schema.table", 'database("schema").["table"]'), ('schema."table.name"', 'database("schema").["table.name"]'), @@ -189,10 +227,16 @@ def test_schema_from_metadata(table_name: str, schema_name: str, expected_table_ ('["table"]', '["table"]'), ], ) -def test_schema_from_query(query_table_name: str, expected_table_name: str): - query = select("*").select_from(TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry")).limit(5) +def test_schema_from_query(query_table_name: str, expected_table_name: str) -> None: + query = ( + select("*") + .select_from(TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry")) + .limit(5) + ) - query_compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})).replace("\n", "") + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") query_expected = f"let inner_qry = ({expected_table_name});inner_qry| take 5" assert query_compiled == query_expected