summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd364..4a461a0abb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections.abc
 import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return room_version
 
-    async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+    async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         predecessor = create_event.content.get("predecessor", None)
 
         # Ensure the key is a dictionary
-        if not isinstance(predecessor, collections.abc.Mapping):
+        if not isinstance(predecessor, (dict, frozendict)):
             return None
 
+        # The keys must be strings since the data is JSON.
         return predecessor
 
     async def get_create_event_for_room(self, room_id: str) -> EventBase:
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
-        """Returns mapping event_id -> state_group"""
+    async def _get_state_group_for_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, int]:
+        """Returns mapping event_id -> state_group.
+
+        Raises:
+             RuntimeError if the state is unknown at any of the given events
+        """
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="_get_state_group_for_events",
         )
 
-        return {row["event_id"]: row["state_group"] for row in rows}
+        res = {row["event_id"]: row["state_group"] for row in rows}
+        for e in event_ids:
+            if e not in res:
+                raise RuntimeError("No state group for unknown or outlier event %s" % e)
+        return res
 
     async def get_referenced_state_groups(
         self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
         )
 
         for user_id in potentially_left_users - joined_users:
-            await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+            await self.mark_remote_user_device_list_as_unsubscribed(user_id)  # type: ignore[attr-defined]
 
         return batch_size