summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56606e9afb..fcb7e829d4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -298,12 +298,18 @@ class StateHandler:
 
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
 
             # We make sure that we have a state group assigned to the state.
             if entry.state_group is None:
-                state_ids_before_event = await entry.get_state(
-                    self._state_storage_controller, StateFilter.all()
-                )
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
+
                 state_group_before_event = (
                     await self._state_storage_controller.store_state_group(
                         event.event_id,
@@ -316,7 +322,6 @@ class StateHandler:
                 entry.state_group = state_group_before_event
             else:
                 state_group_before_event = entry.state_group
-                state_ids_before_event = None
 
         #
         # now if it's not a state event, we're done
@@ -336,19 +341,20 @@ class StateHandler:
         #
         # otherwise, we'll need to create a new state group for after the event
         #
-        if state_ids_before_event is None:
-            state_ids_before_event = await entry.get_state(
-                self._state_storage_controller, StateFilter.all()
-            )
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -357,7 +363,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )