diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3b0ff786c25..f4c67176b71 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -121,6 +121,12 @@ async def on_rdata( ) self.store.device_lists_in_rooms_have_changed(all_room_ids, token) + # If we're sending federation we need to update the device lists + # outbound pokes stream change cache with updated hosts. + if self.send_handler and any(row.hosts_calculated for row in rows): + hosts = await self.store.get_destinations_for_device(token) + self.store.device_lists_outbound_pokes_have_changed(hosts, token) + self.store.process_replication_rows(stream_name, instance_name, token, rows) # NOTE: this must be called after process_replication_rows to ensure any # cache invalidations are first handled before any stream ID advances. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index ec2ea8e395f..76abf1e5ba3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -164,22 +164,24 @@ def __init__( prefilled_cache=user_signature_stream_prefill, ) - ( - device_list_federation_prefill, - device_list_federation_list_id, - ) = self.db_pool.get_cache_dict( - db_conn, - "device_lists_outbound_pokes", - entity_column="destination", - stream_column="stream_id", - max_value=device_list_max, - limit=10000, - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", - device_list_federation_list_id, - prefilled_cache=device_list_federation_prefill, - ) + self._device_list_federation_stream_cache = None + if hs.get_federation_sender() is not None: + ( + device_list_federation_prefill, + device_list_federation_list_id, + ) = self.db_pool.get_cache_dict( + db_conn, + "device_lists_outbound_pokes", + entity_column="destination", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", + device_list_federation_list_id, + prefilled_cache=device_list_federation_prefill, + ) if hs.config.worker.run_background_tasks: self._clock.looping_call( @@ -221,11 +223,15 @@ def _invalidate_caches_for_devices( (row.user_id,) ) - else: - # self._device_list_federation_stream_cache.entity_has_changed( - # row.entity, token - # ) - pass + def device_lists_outbound_pokes_have_changed( + self, destinations: StrCollection, token: int + ) -> None: + assert self._device_list_federation_stream_cache is not None + + for destination in destinations: + self._device_list_federation_stream_cache.entity_has_changed( + destination, token + ) def device_lists_in_rooms_have_changed( self, room_ids: StrCollection, token: int @@ -369,18 +375,21 @@ async def get_device_updates_by_remote( if from_stream_id == now_stream_id: return now_stream_id, [] - # has_changed = self._device_list_federation_stream_cache.has_entity_changed( - # destination, int(from_stream_id) - # ) - # if not has_changed: - # # debugging for https://github.com/matrix-org/synapse/issues/14251 - # issue_8631_logger.debug( - # "%s: no change between %i and %i", - # destination, - # from_stream_id, - # now_stream_id, - # ) - # return now_stream_id, [] + if self._device_list_federation_stream_cache is None: + raise Exception("Func can only be used on federation senders") + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + # debugging for https://github.com/matrix-org/synapse/issues/14251 + issue_8631_logger.debug( + "%s: no change between %i and %i", + destination, + from_stream_id, + now_stream_id, + ) + return now_stream_id, [] updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", @@ -2125,12 +2134,13 @@ def _add_device_outbound_poke_to_stream_txn( stream_ids: List[int], context: Optional[Dict[str, str]], ) -> None: - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_ids[-1], - ) + if self._device_list_federation_stream_cache: + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) now = self._clock.time_msec() stream_id_iterator = iter(stream_ids)