summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py35
1 files changed, 19 insertions, 16 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 0219091c4e..4b4ed42cff 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -288,7 +288,6 @@ class StateHandler:
         #
         # first of all, figure out the state before the event
         #
-
         if old_state:
             # if we're given the state before the event, then we use that
             state_ids_before_event: StateMap[str] = {
@@ -419,33 +418,37 @@ class StateHandler:
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        # map from state group id to the state in that state group (where
-        # 'state' is a map from state key to event id)
-        # dict[int, dict[(str, str), str]]
-        state_groups_ids = await self.state_store.get_state_groups_ids(
-            room_id, event_ids
-        )
-
-        if len(state_groups_ids) == 0:
-            return _StateCacheEntry(state={}, state_group=None)
-        elif len(state_groups_ids) == 1:
-            name, state_list = list(state_groups_ids.items()).pop()
+        state_groups = await self.state_store.get_state_group_for_events(event_ids)
 
-            prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
+        state_group_ids = state_groups.values()
 
+        # check if each event has same state group id, if so there's no state to resolve
+        state_group_ids_set = set(state_group_ids)
+        if len(state_group_ids_set) == 1:
+            (state_group_id,) = state_group_ids_set
+            state = await self.state_store.get_state_for_groups(state_group_ids_set)
+            prev_group, delta_ids = await self.state_store.get_state_group_delta(
+                state_group_id
+            )
             return _StateCacheEntry(
-                state=state_list,
-                state_group=name,
+                state=state[state_group_id],
+                state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
             )
+        elif len(state_group_ids_set) == 0:
+            return _StateCacheEntry(state={}, state_group=None)
 
         room_version = await self.store.get_room_version_id(room_id)
 
+        state_to_resolve = await self.state_store.get_state_for_groups(
+            state_group_ids_set
+        )
+
         result = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
-            state_groups_ids,
+            state_to_resolve,
             None,
             state_res_store=StateResolutionStore(self.store),
         )