summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py32
1 files changed, 20 insertions, 12 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a601303fa3..9bf2ec368f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,6 +25,7 @@ from typing import (
     Sequence,
     Set,
     Union,
+    cast,
     overload,
 )
 
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -205,7 +206,7 @@ class StateHandler(object):
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return dict(ret.state)
+        return ret.state
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -302,7 +303,7 @@ class StateHandler(object):
             # if we're given the state before the event, then we use that
             state_ids_before_event = {
                 (s.type, s.state_key): s.event_id for s in old_state
-            }
+            }  # type: StateMap[str]
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
@@ -315,7 +316,7 @@ class StateHandler(object):
                 event.room_id, event.prev_event_ids()
             )
 
-            state_ids_before_event = dict(entry.state)
+            state_ids_before_event = entry.state
             state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
@@ -540,7 +541,7 @@ class StateResolutionHandler(object):
             #
             # XXX: is this actually worthwhile, or should we just let
             # resolve_events_with_store do it?
-            new_state = {}
+            new_state = {}  # type: MutableStateMap[str]
             conflicted_state = False
             for st in state_groups_ids.values():
                 for key, e_id in st.items():
@@ -554,13 +555,20 @@ class StateResolutionHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = await resolve_events_with_store(
-                        self.clock,
-                        room_id,
-                        room_version,
-                        list(state_groups_ids.values()),
-                        event_map=event_map,
-                        state_res_store=state_res_store,
+                    # resolve_events_with_store returns a StateMap, but we can
+                    # treat it as a MutableStateMap as it is above. It isn't
+                    # actually mutated anymore (and is frozen in
+                    # _make_state_cache_entry below).
+                    new_state = cast(
+                        MutableStateMap,
+                        await resolve_events_with_store(
+                            self.clock,
+                            room_id,
+                            room_version,
+                            list(state_groups_ids.values()),
+                            event_map=event_map,
+                            state_res_store=state_res_store,
+                        ),
                     )
 
             # if the new state matches any of the input state groups, we can