diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f08ce6617d9..ec2ea8e395f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -164,6 +164,23 @@ def __init__( prefilled_cache=user_signature_stream_prefill, ) + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) + if hs.config.worker.run_background_tasks: self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 @@ -204,6 +221,12 @@ def _invalidate_caches_for_devices( (row.user_id,) ) + else: + # self._device_list_federation_stream_cache.entity_has_changed( + # row.entity, token + # ) + pass + def device_lists_in_rooms_have_changed( self, room_ids: StrCollection, token: int ) -> None: @@ -346,6 +369,19 @@ async def get_device_updates_by_remote( if from_stream_id == now_stream_id: return now_stream_id, [] + # has_changed = self._device_list_federation_stream_cache.has_entity_changed( + # destination, int(from_stream_id) + # ) + # if not has_changed: + # # debugging for https://github.com/matrix-org/synapse/issues/14251 + # issue_8631_logger.debug( + # "%s: no change between %i and %i", + # destination, + # from_stream_id, + # now_stream_id, + # ) + # return now_stream_id, [] + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, @@ -2089,6 +2125,13 @@ def _add_device_outbound_poke_to_stream_txn( stream_ids: List[int], context: Optional[Dict[str, str]], ) -> None: + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + now = self._clock.time_msec() stream_id_iterator = iter(stream_ids)