diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index caabce8e4ce..fc9baaf6943 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -30,13 +30,11 @@ ) import attr -from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes from synapse.api.errors import Codes, NotFoundError, ShadowBanError, SynapseError from synapse.events import EventBase -from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.delayed_events import ( @@ -92,15 +90,8 @@ def __init__(self, hs: "HomeServer"): async def _schedule_db_events() -> None: # TODO: Sync all state first, so that affected delayed state events will be cancelled events, remaining_timeout_delays = await self.store.process_all_delays(self._get_current_ts()) - await make_deferred_yieldable( - defer.gatherResults( - [ - run_as_background_process("_send_event", self._send_event, *args) - for args in events - ], - consumeErrors=True, - ) - ) + for args in events: + await self._send_event(*args) for delay_id, user_localpart, relative_delay in remaining_timeout_delays: self._schedule(delay_id, user_localpart, relative_delay) @@ -243,14 +234,11 @@ async def update(self, requester: Requester, delay_id: str, action: str) -> None self._schedule(delay_id, user_localpart, delay) elif enum_action == _UpdateDelayedEventAction.SEND: - await self._send_now(delay_id, user_localpart) - - async def _send_now(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: - args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart) + args, removed_timeout_delay_ids = await self.store.pop_event(delay_id, user_localpart) - for timeout_delay_id in removed_timeout_delay_ids: - self._unschedule(timeout_delay_id, user_localpart) - await self._send_event(user_localpart, *args) + for timeout_delay_id in removed_timeout_delay_ids: + self._unschedule(timeout_delay_id, user_localpart) + await self._send_event(user_localpart, *args) async def _send_on_timeout(self, delay_id: DelayID, user_localpart: UserLocalpart) -> None: del self._delayed_calls[_DelayedCallKey(delay_id, user_localpart)] diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index ef4921d77ab..bd876b36e9d 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -61,6 +61,16 @@ JsonDict, ] +# TODO: If a Tuple type hint can be extended, extend the above one +DelayedPartialEventWithUser = Tuple[ + UserLocalpart, + RoomID, + EventType, + Optional[StateKey], + Optional[Timestamp], + JsonDict, +] + # TODO: Try to support workers class DelayedEventsStore(SQLBaseStore): @@ -295,7 +305,7 @@ async def get_all_for_user( ] async def process_all_delays(self, current_ts: Timestamp) -> Tuple[ - List[DelayedPartialEvent], + List[DelayedPartialEventWithUser], List[Tuple[DelayID, UserLocalpart, Delay]], ]: """ @@ -305,34 +315,34 @@ async def process_all_delays(self, current_ts: Timestamp) -> Tuple[ """ def process_all_delays_txn(txn: LoggingTransaction) -> Tuple[ - List[DelayedPartialEvent], + List[DelayedPartialEventWithUser], List[Tuple[DelayID, UserLocalpart, Delay]], ]: - events: List[DelayedPartialEvent] = [] + events: List[DelayedPartialEventWithUser] = [] removed_timeout_delay_ids: Set[DelayID] = set() txn.execute( """ WITH delay_send_times AS ( - SELECT delay_rowid, running_since + delay AS send_ts + SELECT delay_rowid, user_localpart, running_since + delay AS send_ts FROM delayed_events JOIN delayed_event_timeouts USING (delay_rowid) ) - SELECT delay_rowid + SELECT delay_rowid, user_localpart FROM delay_send_times WHERE send_ts < ? ORDER BY send_ts """, (current_ts,), ) - for (delay_rowid,) in txn: + for row in txn: try: event, removed_timeout_delay_id = self._pop_event_txn( txn, - keyvalues={"delay_rowid": delay_rowid}, + keyvalues={"delay_rowid": row[0]}, ) except NotFoundError: pass - events.append(event) + events.append((UserLocalpart(row[1]), *event)) removed_timeout_delay_ids |= removed_timeout_delay_id txn.execute(