Skip to content

Commit

Permalink
Fob
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Jan 5, 2024
1 parent 81b1c56 commit d1adcb6
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
12 changes: 11 additions & 1 deletion synapse/handlers/room_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection
from synapse.types.state import StateFilter
from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
Expand Down Expand Up @@ -546,7 +547,16 @@ async def _is_local_room_accessible(
Returns:
True if the room is accessible to the requesting user or server.
"""
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
event_types = [
(EventTypes.JoinRules, ""),
(EventTypes.RoomHistoryVisibility, ""),
]
if requester:
event_types.append((EventTypes.Member, requester))

state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, state_filter=StateFilter.from_types(event_types)
)

# If there's no state for the room, it isn't known.
if not state_ids:
Expand Down
48 changes: 46 additions & 2 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import attr
Expand All @@ -52,7 +55,7 @@
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -64,6 +67,8 @@

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


MAX_STATE_DELTA_HOPS = 100

Expand Down Expand Up @@ -349,7 +354,8 @@ async def get_partial_filtered_current_state_ids(
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
results = StateMapWrapper(state_filter=state_filter or StateFilter.all())

sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
Expand Down Expand Up @@ -726,3 +732,41 @@ def __init__(
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)


@attr.s(auto_attribs=True, slots=True)
class StateMapWrapper(Dict[StateKey, str]):
"""A wrapper around a StateMap[str] to ensure that we only query for items
that were not filtered out.
This is to help prevent bugs where we filter out state but other bits of the
code expect the state to be there.
"""

state_filter: StateFilter

def __getitem__(self, key: StateKey) -> str:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().__getitem__(key)

@overload
def get(self, key: Tuple[str, str]) -> Optional[str]:
...

@overload
def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
...

def get(
self, key: StateKey, default: Union[str, _T, None] = None
) -> Union[str, _T, None]:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().get(key, default)

def __contains__(self, key: object) -> bool:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)

return super().__contains__(key)
11 changes: 11 additions & 0 deletions synapse/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
# local users only
return False

def __contains__(self, key: Tuple[str, str]) -> bool:
typ, state_key = key
if typ in self.types:
state_keys = self.types[typ]
if state_keys is None or state_key in state_keys:
return True
elif self.include_others:
return True

return False


_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
Expand Down

0 comments on commit d1adcb6

Please sign in to comment.