summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-07-15 13:59:45 +0100
committerGitHub <noreply@github.com>2022-07-15 12:59:45 +0000
commit0731e0829c08aec7f31fdc72c236757e4cc38747 (patch)
treeb1a7046cf12a1bf24f9affd9fdca26180b863d0f /synapse/state/__init__.py
parentUse a real room in the notification rotation tests. (#13260) (diff)
downloadsynapse-0731e0829c08aec7f31fdc72c236757e4cc38747.tar.xz
Don't pull out the full state when storing state (#13274)
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56606e9afb..fcb7e829d4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -298,12 +298,18 @@ class StateHandler:
 
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
 
             # We make sure that we have a state group assigned to the state.
             if entry.state_group is None:
-                state_ids_before_event = await entry.get_state(
-                    self._state_storage_controller, StateFilter.all()
-                )
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
+
                 state_group_before_event = (
                     await self._state_storage_controller.store_state_group(
                         event.event_id,
@@ -316,7 +322,6 @@ class StateHandler:
                 entry.state_group = state_group_before_event
             else:
                 state_group_before_event = entry.state_group
-                state_ids_before_event = None
 
         #
         # now if it's not a state event, we're done
@@ -336,19 +341,20 @@ class StateHandler:
         #
         # otherwise, we'll need to create a new state group for after the event
         #
-        if state_ids_before_event is None:
-            state_ids_before_event = await entry.get_state(
-                self._state_storage_controller, StateFilter.all()
-            )
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -357,7 +363,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )