diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 9c443a8009b..06c44bb5631 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -765,7 +765,7 @@ def get( raise Exception("State map was filtered and doesn't include: %s", key) return super().get(key, default) - def __contains__(self, key: object) -> bool: + def __contains__(self, key: Any) -> bool: if key not in self.state_filter: raise Exception("State map was filtered and doesn't include: %s", key) diff --git a/synapse/types/state.py b/synapse/types/state.py index 36024e1b426..937ffe3f979 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -20,6 +20,7 @@ import logging from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -584,12 +585,17 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: # local users only return False - def __contains__(self, key: Tuple[str, str]) -> bool: + def __contains__(self, key: Any) -> bool: + if not isinstance(key, tuple) or len(key) != 2: + return False + typ, state_key = key + if typ in self.types: state_keys = self.types[typ] if state_keys is None or state_key in state_keys: return True + elif self.include_others: return True