Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/psycopg3' into psycopg3
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Oct 23, 2024
2 parents 3bbd562 + f5b6429 commit 7ff4584
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 36 deletions.
19 changes: 7 additions & 12 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
PsycopgEngine,
Sqlite3Engine,
)
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.types import StrCollection
from synapse.util.async_helpers import delay_cancellation
Expand Down Expand Up @@ -408,7 +409,7 @@ def execute_values(
values: Collection[Iterable[Any]],
template: Optional[str] = None,
fetch: bool = True,
) -> List[Tuple]:
) -> Iterable[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
Expand Down Expand Up @@ -453,7 +454,7 @@ def execute_values(
def f(
the_sql: str, the_args: Sequence[Sequence[Any]]
) -> Iterable[Tuple[Any, ...]]:
with self.txn.copy(the_sql, the_args) as copy:
with self.txn.copy(the_sql, the_args) as copy: # type: ignore[attr-defined]
yield from copy.rows()

# Flatten the values.
Expand All @@ -468,7 +469,7 @@ def copy_write(
def f(
the_sql: str, the_args: Iterable[Any], the_values: Iterable[Iterable[Any]]
) -> None:
with self.txn.copy(the_sql, the_args) as copy:
with self.txn.copy(the_sql, the_args) as copy: # type: ignore[attr-defined]
for record in the_values:
copy.write_row(record)

Expand Down Expand Up @@ -504,12 +505,6 @@ def executescript(self, sql: str) -> None:

def _make_sql_one_line(self, sql: str) -> str:
    """Strip newlines out of SQL so that the loggers in the DB are on one line."""
    if isinstance(self.database_engine, PsycopgEngine):
        # psycopg may hand us a Composed object instead of a plain string;
        # let the driver render it rather than splitting lines ourselves.
        import psycopg.sql

        if isinstance(sql, psycopg.sql.Composed):
            return sql.as_string(None)

    stripped_lines = (line.strip() for line in sql.splitlines())
    return " ".join(part for part in stripped_lines if part)

def _do_execute(
Expand Down Expand Up @@ -933,7 +928,7 @@ async def runInteraction(
func: Callable[..., R],
*args: Any,
db_autocommit: bool = False,
isolation_level: Optional[int] = None,
isolation_level: Optional[IsolationLevel] = None,
**kwargs: Any,
) -> R:
"""Starts a transaction on the database and runs a given function
Expand Down Expand Up @@ -1015,7 +1010,7 @@ async def runWithConnection(
func: Callable[Concatenate[LoggingDatabaseConnection, P], R],
*args: Any,
db_autocommit: bool = False,
isolation_level: Optional[int] = None,
isolation_level: Optional[IsolationLevel] = None,
**kwargs: Any,
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Expand Down Expand Up @@ -2421,7 +2416,7 @@ def simple_delete_many_batch_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
values: Iterable[Iterable[Any]],
values: Sequence[Iterable[Any]],
) -> None:
"""Executes a DELETE query on the named table.
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ async def claim_e2e_fallback_keys(
def _claim_e2e_fallback_keys_bulk_txn(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, bool]],
query_list: Collection[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Efficient implementation of claim_e2e_fallback_keys for Postgres.
Expand Down Expand Up @@ -1342,7 +1342,7 @@ def _claim_e2e_one_time_key_simple(
def _claim_e2e_one_time_keys_bulk(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, int]],
query_list: Collection[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> N

@abc.abstractmethod
def attempt_to_set_isolation_level(
self, conn: ConnectionType, isolation_level: Optional[IsolationLevelType]
self, conn: ConnectionType, isolation_level: Optional[IsolationLevel] = None
) -> None:
"""Attempt to set the connections isolation level.
Expand Down
16 changes: 10 additions & 6 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,30 @@

import abc
import logging
from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple, cast

from synapse.storage.engines._base import (
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
BaseDatabaseEngine,
ConnectionType,
CursorType,
IncorrectDatabaseSetup,
IsolationLevel,
IsolationLevelType,
)
from synapse.storage.types import Cursor, DBAPI2Module

if TYPE_CHECKING:
from synapse.storage.database import LoggingDatabaseConnection


logger = logging.getLogger(__name__)


class PostgresEngine(
Generic[ConnectionType, CursorType, IsolationLevelType],
BaseDatabaseEngine[ConnectionType, CursorType, IsolationLevelType],
metaclass=abc.ABCMeta,
):
isolation_level_map: Mapping[int, IsolationLevelType]
isolation_level_map: Mapping[IsolationLevel, IsolationLevelType]
default_isolation_level: IsolationLevelType

def __init__(self, module: DBAPI2Module, database_config: Mapping[str, Any]):
Expand Down Expand Up @@ -173,7 +172,7 @@ def convert_param_style(self, sql: str) -> str:

def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
# mypy doesn't realize that ConnectionType matches the Connection protocol.
self.attempt_to_set_isolation_level(db_conn.conn, self.default_isolation_level) # type: ignore[arg-type]
self.attempt_to_set_isolation_level(db_conn.conn) # type: ignore[arg-type]

# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
Expand All @@ -187,7 +186,12 @@ def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:

# Abort really long-running statements and turn them into errors.
if self.statement_timeout is not None:
self.set_statement_timeout(cursor.txn, self.statement_timeout)
# Because the PostgresEngine is considered an ABCMeta, a superclass and a
# subclass, cursor's type is messy. We know it should be a CursorType,
# but for now that doesn't pass cleanly through LoggingDatabaseConnection
# and LoggingTransaction. Fortunately, it's merely running an execute()
# and nothing more exotic.
self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type]

cursor.close()
db_conn.commit()
Expand Down
7 changes: 4 additions & 3 deletions synapse/storage/engines/psycopg.py
Original file line number Diff line number Diff line change
def set_statement_timeout(
    self, cursor: psycopg.Cursor, statement_timeout: int
) -> None:
    """Configure the current cursor's statement timeout."""
    # Build the SET statement via psycopg's SQL composition (safe value
    # interpolation), then render it to a plain string before executing.
    timeout_sql = psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout)
    cursor.execute(timeout_sql.as_string())

def convert_param_style(self, sql: str) -> str:
# if isinstance(sql, psycopg.sql.Composed):
Expand Down Expand Up @@ -87,7 +88,7 @@ def attempt_to_set_autocommit(
conn.autocommit = autocommit

def attempt_to_set_isolation_level(
self, conn: psycopg.Connection, isolation_level: Optional[int]
self, conn: psycopg.Connection, isolation_level: Optional[IsolationLevel] = None
) -> None:
if isolation_level is None:
pg_isolation_level = self.default_isolation_level
Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/engines/psycopg2.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ def attempt_to_set_autocommit(
return conn.set_session(autocommit=autocommit)

def attempt_to_set_isolation_level(
    self,
    conn: psycopg2.extensions.connection,
    isolation_level: Optional[IsolationLevel] = None,
) -> None:
    # Translate the Synapse-level IsolationLevel into the psycopg2-specific
    # constant, falling back to the engine's default when none is given.
    pg_isolation_level = (
        self.default_isolation_level
        if isolation_level is None
        else self.isolation_level_map[isolation_level]
    )
    return conn.set_isolation_level(pg_isolation_level)
7 changes: 5 additions & 2 deletions synapse/storage/engines/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER
from synapse.storage.engines._base import (
AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER,
IsolationLevel,
)
from synapse.storage.types import Cursor

if TYPE_CHECKING:
Expand Down Expand Up @@ -146,7 +149,7 @@ def attempt_to_set_autocommit(
pass

def attempt_to_set_isolation_level(
    self, conn: sqlite3.Connection, isolation_level: Optional[IsolationLevel] = None
) -> None:
    # Deliberately a no-op: all transactions are SERIALIZABLE by default in
    # sqlite, so there is no level to configure on the connection.
    pass
Expand Down
10 changes: 5 additions & 5 deletions tests/storage/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]

def tearDown(self) -> None:
    # The execute_batch/execute_values patchers are presumably started by the
    # fixture only for psycopg2-backed Postgres runs (setUp not visible here
    # — confirm); skip teardown for sqlite (falsy) and for the psycopg driver.
    if not USE_POSTGRES_FOR_TESTS or USE_POSTGRES_FOR_TESTS == "psycopg":
        return
    self.execute_batch_patcher.stop()
    self.execute_values_patcher.stop()

Expand Down Expand Up @@ -388,7 +388,7 @@ def test_update_many(self) -> Generator["defer.Deferred[object]", object, None]:
)

# execute_batch is only used on psycopg2.
if USE_POSTGRES_FOR_TESTS != "psycopg":
if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg":
self.mock_execute_batch.assert_called_once_with(
self.mock_txn,
"UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?",
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_update_many_no_iterable(
)

# execute_batch is only used on psycopg2.
if USE_POSTGRES_FOR_TESTS != "psycopg":
if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg":
self.mock_execute_batch.assert_not_called()
else:
self.mock_txn.executemany.assert_not_called()
Expand Down Expand Up @@ -601,7 +601,7 @@ def test_upsert_many(self) -> Generator["defer.Deferred[object]", object, None]:
)

# execute_values is only used on psycopg2.
if USE_POSTGRES_FOR_TESTS != "psycopg":
if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg":
self.mock_execute_values.assert_called_once_with(
self.mock_txn,
"INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES ? ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3",
Expand Down Expand Up @@ -631,7 +631,7 @@ def test_upsert_many_no_values(
)

# execute_values is only used on psycopg2.
if USE_POSTGRES_FOR_TESTS != "psycopg":
if USE_POSTGRES_FOR_TESTS and USE_POSTGRES_FOR_TESTS != "psycopg":
self.mock_execute_values.assert_called_once_with(
self.mock_txn,
"INSERT INTO tablename (columnname) VALUES ? ON CONFLICT (columnname) DO NOTHING",
Expand Down

0 comments on commit 7ff4584

Please sign in to comment.