From d1adcb6af863a7ae0b2a2385ff43abfd7b7213e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 5 Jan 2024 17:02:08 +0000 Subject: [PATCH] Fob --- synapse/handlers/room_summary.py | 12 ++++++- synapse/storage/databases/main/state.py | 48 +++++++++++++++++++++++-- synapse/types/state.py | 11 ++++++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index a534f5f280b..78bcac1429e 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -44,6 +44,7 @@ from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection +from synapse.types.state import StateFilter from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -546,7 +547,16 @@ async def _is_local_room_accessible( Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) + event_types = [ + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ] + if requester: + event_types.append((EventTypes.Member, requester)) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, state_filter=StateFilter.from_types(event_types) + ) # If there's no state for the room, it isn't known. if not state_ids: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4700e74ad28..9c443a8009b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -30,7 +30,10 @@ Optional, Set, Tuple, + TypeVar, + Union, cast, + overload, ) import attr @@ -52,7 +55,7 @@ ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types import JsonDict, JsonMapping, StateKey, StateMap from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -64,6 +67,8 @@ logger = logging.getLogger(__name__) +_T = TypeVar("_T") + MAX_STATE_DELTA_HOPS = 100 @@ -349,7 +354,8 @@ async def get_partial_filtered_current_state_ids( def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = {} + results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + sql = """ SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? @@ -726,3 +732,41 @@ def __init__( hs: "HomeServer", ): super().__init__(database, db_conn, hs) + + +@attr.s(auto_attribs=True, slots=True) +class StateMapWrapper(Dict[StateKey, str]): + """A wrapper around a StateMap[str] to ensure that we only query for items + that were not filtered out. + + This is to help prevent bugs where we filter out state but other bits of the + code expect the state to be there. + """ + + state_filter: StateFilter + + def __getitem__(self, key: StateKey) -> str: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().__getitem__(key) + + @overload + def get(self, key: Tuple[str, str]) -> Optional[str]: + ... + + @overload + def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: + ... + + def get( + self, key: StateKey, default: Union[str, _T, None] = None + ) -> Union[str, _T, None]: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().get(key, default) + + def __contains__(self, key: object) -> bool: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + + return super().__contains__(key) diff --git a/synapse/types/state.py b/synapse/types/state.py index 5ca3c94bceb..36024e1b426 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -584,6 +584,17 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: # local users only return False + def __contains__(self, key: Tuple[str, str]) -> bool: + typ, state_key = key + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + return True + elif self.include_others: + return True + + return False + _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter(