diff options
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r-- | synapse/storage/databases/main/state.py | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 28460fd364..4a461a0abb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import logging -from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple + +from frozendict import frozendict from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -29,7 +30,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter -from synapse.types import JsonDict, StateMap +from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version - async def get_room_predecessor(self, room_id: str) -> Optional[dict]: + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. @@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, collections.abc.Mapping): + if not isinstance(predecessor, (dict, frozendict)): return None + # The keys must be strings since the data is JSON. return predecessor async def get_create_event_for_room(self, room_id: str) -> EventBase: @@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list_name="event_ids", num_args=1, ) - async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict: - """Returns mapping event_id -> state_group""" + async def _get_state_group_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, int]: + """Returns mapping event_id -> state_group. + + Raises: + RuntimeError if the state is unknown at any of the given events + """ rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", @@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="_get_state_group_for_events", ) - return {row["event_id"]: row["state_group"] for row in rows} + res = {row["event_id"]: row["state_group"] for row in rows} + for e in event_ids: + if e not in res: + raise RuntimeError("No state group for unknown or outlier event %s" % e) + return res async def get_referenced_state_groups( self, state_groups: Iterable[int] @@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): ) for user_id in potentially_left_users - joined_users: - await self.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined] return batch_size |