diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bb452c77b6e..52ffba8661f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -71,6 +71,7 @@ PsycopgEngine, Sqlite3Engine, ) +from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.types import StrCollection from synapse.util.async_helpers import delay_cancellation @@ -408,7 +409,7 @@ def execute_values( values: Collection[Iterable[Any]], template: Optional[str] = None, fetch: bool = True, - ) -> List[Tuple]: + ) -> Iterable[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when using postgres. @@ -453,7 +454,7 @@ def execute_values( def f( the_sql: str, the_args: Sequence[Sequence[Any]] ) -> Iterable[Tuple[Any, ...]]: - with self.txn.copy(the_sql, the_args) as copy: + with self.txn.copy(the_sql, the_args) as copy: # type: ignore[attr-defined] yield from copy.rows() # Flatten the values. @@ -468,7 +469,7 @@ def copy_write( def f( the_sql: str, the_args: Iterable[Any], the_values: Iterable[Iterable[Any]] ) -> None: - with self.txn.copy(the_sql, the_args) as copy: + with self.txn.copy(the_sql, the_args) as copy: # type: ignore[attr-defined] for record in the_values: copy.write_row(record) @@ -504,12 +505,6 @@ def executescript(self, sql: str) -> None: def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" - if isinstance(self.database_engine, PsycopgEngine): - import psycopg.sql - - if isinstance(sql, psycopg.sql.Composed): - return sql.as_string(None) - return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute( @@ -933,7 +928,7 @@ async def runInteraction( func: Callable[..., R], *args: Any, db_autocommit: bool = False, - isolation_level: Optional[int] = None, + isolation_level: Optional[IsolationLevel] = None, **kwargs: Any, ) -> R: """Starts a transaction on the database and runs a given function @@ -1015,7 +1010,7 @@ async def runWithConnection( func: Callable[Concatenate[LoggingDatabaseConnection, P], R], *args: Any, db_autocommit: bool = False, - isolation_level: Optional[int] = None, + isolation_level: Optional[IsolationLevel] = None, **kwargs: Any, ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -2421,7 +2416,7 @@ def simple_delete_many_batch_txn( txn: LoggingTransaction, table: str, keys: Collection[str], - values: Iterable[Iterable[Any]], + values: Sequence[Iterable[Any]], ) -> None: """Executes a DELETE query on the named table. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 6e0b58ac609..da79af37a43 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1214,7 +1214,7 @@ async def claim_e2e_fallback_keys( def _claim_e2e_fallback_keys_bulk_txn( self, txn: LoggingTransaction, - query_list: Iterable[Tuple[str, str, str, bool]], + query_list: Collection[Tuple[str, str, str, bool]], ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: """Efficient implementation of claim_e2e_fallback_keys for Postgres. @@ -1342,7 +1342,7 @@ def _claim_e2e_one_time_key_simple( def _claim_e2e_one_time_keys_bulk( self, txn: LoggingTransaction, - query_list: Iterable[Tuple[str, str, str, int]], + query_list: Collection[Tuple[str, str, str, int]], ) -> List[Tuple[str, str, str, str, str]]: """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f42023418e2..610fe8ccb76 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -102,10 +102,10 @@ DatabasePool, LoggingDatabaseConnection, LoggingTransaction, - PostgresEngine, ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 4322c8b5684..50d7364e425 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -132,7 +132,7 @@ def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> N @abc.abstractmethod def attempt_to_set_isolation_level( - self, conn: ConnectionType, isolation_level: Optional[IsolationLevelType] + self, conn: ConnectionType, isolation_level: Optional[IsolationLevel] = None ) -> None: """Attempt to set the connections isolation level. diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 069aa3485b9..f6e296e2f09 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,7 +21,7 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple, cast from synapse.storage.engines._base import ( AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, @@ -29,6 +29,7 @@ ConnectionType, CursorType, IncorrectDatabaseSetup, + IsolationLevel, IsolationLevelType, ) from synapse.storage.types import Cursor, DBAPI2Module @@ -36,16 +37,14 @@ if TYPE_CHECKING: from synapse.storage.database import LoggingDatabaseConnection - logger = logging.getLogger(__name__) class PostgresEngine( - Generic[ConnectionType, CursorType, IsolationLevelType], BaseDatabaseEngine[ConnectionType, CursorType, IsolationLevelType], metaclass=abc.ABCMeta, ): - isolation_level_map: Mapping[int, IsolationLevelType] + isolation_level_map: Mapping[IsolationLevel, IsolationLevelType] default_isolation_level: IsolationLevelType def __init__(self, module: DBAPI2Module, database_config: Mapping[str, Any]): @@ -173,7 +172,7 @@ def convert_param_style(self, sql: str) -> str: def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: # mypy doesn't realize that ConnectionType matches the Connection protocol. - self.attempt_to_set_isolation_level(db_conn.conn, self.default_isolation_level) # type: ignore[arg-type] + self.attempt_to_set_isolation_level(db_conn.conn) # type: ignore[arg-type] # Set the bytea output to escape, vs the default of hex cursor = db_conn.cursor() @@ -187,7 +186,12 @@ def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: # Abort really long-running statements and turn them into errors. if self.statement_timeout is not None: - self.set_statement_timeout(cursor.txn, self.statement_timeout) + # Because the PostgresEngine is considered an ABCMeta, a superclass and a + # subclass, cursor's type is messy. We know it should be a CursorType, + # but for now that doesn't pass cleanly through LoggingDatabaseConnection + # and LoggingTransaction. Fortunately, it's merely running an execute() + # and nothing more exotic. + self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type] cursor.close() db_conn.commit() diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py index 6dd01319e14..e591a6bf6c4 100644 --- a/synapse/storage/engines/psycopg.py +++ b/synapse/storage/engines/psycopg.py @@ -56,9 +56,10 @@ def set_statement_timeout( self, cursor: psycopg.Cursor, statement_timeout: int ) -> None: """Configure the current cursor's statement timeout.""" - cursor.execute( - psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout) + query_str = psycopg.sql.SQL("SET statement_timeout TO {}").format( + statement_timeout ) + cursor.execute(query_str.as_string()) def convert_param_style(self, sql: str) -> str: # if isinstance(sql, psycopg.sql.Composed): @@ -87,7 +88,7 @@ def attempt_to_set_autocommit( conn.autocommit = autocommit def attempt_to_set_isolation_level( - self, conn: psycopg.Connection, isolation_level: Optional[int] + self, conn: psycopg.Connection, isolation_level: Optional[IsolationLevel] = None ) -> None: if isolation_level is None: pg_isolation_level = self.default_isolation_level diff --git a/synapse/storage/engines/psycopg2.py b/synapse/storage/engines/psycopg2.py index 9ce1d7ad3bb..e473cabfd9c 100644 --- a/synapse/storage/engines/psycopg2.py +++ b/synapse/storage/engines/psycopg2.py @@ -76,10 +76,12 @@ def attempt_to_set_autocommit( return conn.set_session(autocommit=autocommit) def attempt_to_set_isolation_level( - self, conn: psycopg2.extensions.connection, isolation_level: Optional[int] + self, + conn: psycopg2.extensions.connection, + isolation_level: Optional[IsolationLevel] = None, ) -> None: if isolation_level is None: - isolation_level = self.default_isolation_level + pg_isolation_level = self.default_isolation_level else: - isolation_level = self.isolation_level_map[isolation_level] - return conn.set_isolation_level(isolation_level) + pg_isolation_level = self.isolation_level_map[isolation_level] + return conn.set_isolation_level(pg_isolation_level) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 0e9ea4c1203..5b35d93322a 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -25,7 +25,10 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine -from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER +from synapse.storage.engines._base import ( + AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, + IsolationLevel, +) from synapse.storage.types import Cursor if TYPE_CHECKING: @@ -146,7 +149,7 @@ def attempt_to_set_autocommit( pass def attempt_to_set_isolation_level( - self, conn: sqlite3.Connection, isolation_level: Optional[int] + self, conn: sqlite3.Connection, isolation_level: Optional[IsolationLevel] = None ) -> None: # All transactions are SERIALIZABLE by default in sqlite pass diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 0c792948d5e..d8c3aef6c49 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -123,7 +123,7 @@ def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def] self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type] def tearDown(self) -> None: - if USE_POSTGRES_FOR_TESTS != "psycopg": + if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg": self.execute_batch_patcher.stop() self.execute_values_patcher.stop() @@ -388,7 +388,7 @@ def test_update_many(self) -> Generator["defer.Deferred[object]", object, None]: ) # execute_batch is only used on psycopg2. - if USE_POSTGRES_FOR_TESTS != "psycopg": + if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg": self.mock_execute_batch.assert_called_once_with( self.mock_txn, "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", @@ -429,7 +429,7 @@ def test_update_many_no_iterable( ) # execute_batch is only used on psycopg2. - if USE_POSTGRES_FOR_TESTS != "psycopg": + if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg": self.mock_execute_batch.assert_not_called() else: self.mock_txn.executemany.assert_not_called() @@ -601,7 +601,7 @@ def test_upsert_many(self) -> Generator["defer.Deferred[object]", object, None]: ) # execute_values is only used on psycopg2. - if USE_POSTGRES_FOR_TESTS != "psycopg": + if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg": self.mock_execute_values.assert_called_once_with( self.mock_txn, "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES ? ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3", @@ -631,7 +631,7 @@ def test_upsert_many_no_values( ) # execute_values is only used on psycopg2. - if USE_POSTGRES_FOR_TESTS != "psycopg": + if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg": self.mock_execute_values.assert_called_once_with( self.mock_txn, "INSERT INTO tablename (columnname) VALUES ? ON CONFLICT (columnname) DO NOTHING",