diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4700e74ad2..06c44bb563 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,7 +30,10 @@ from typing import (
Optional,
Set,
Tuple,
+ TypeVar,
+ Union,
cast,
+ overload,
)
import attr
@@ -52,7 +55,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict, JsonMapping, StateMap
+from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -64,6 +67,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+_T = TypeVar("_T")
+
MAX_STATE_DELTA_HOPS = 100
@@ -349,7 +354,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
- results = {}
+ results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+
sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -726,3 +732,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+
+
+@attr.s(auto_attribs=True, slots=True)
+class StateMapWrapper(Dict[StateKey, str]):
+ """A wrapper around a StateMap[str] to ensure that we only query for items
+ that were not filtered out.
+
+ This is to help prevent bugs where we filter out state but other bits of the
+ code expect the state to be there.
+ """
+
+ state_filter: StateFilter
+
+ def __getitem__(self, key: StateKey) -> str:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+ return super().__getitem__(key)
+
+ @overload
+ def get(self, key: Tuple[str, str]) -> Optional[str]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
+ ...
+
+ def get(
+ self, key: StateKey, default: Union[str, _T, None] = None
+ ) -> Union[str, _T, None]:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+ return super().get(key, default)
+
+ def __contains__(self, key: Any) -> bool:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+
+ return super().__contains__(key)
|