summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/events.py65
1 files changed, 48 insertions, 17 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 87d0b78a2d..6d2aae9c4a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -344,7 +344,9 @@ class EventsStore(EventsWorkerStore):
                 new_forward_extremeties = {}
 
                 # map room_id->(type,state_key)->event_id tracking the full
-                # state in each room after adding these events
+                # state in each room after adding these events.
+                # This is simply used to prefill the get_current_state_ids
+                # cache
                 current_state_for_room = {}
 
                 # map room_id->(to_delete, to_insert) where to_delete is a list
@@ -419,28 +421,40 @@ class EventsStore(EventsWorkerStore):
                             logger.info(
                                 "Calculating state delta for room %s", room_id,
                             )
-
                             with Measure(
-                                    self._clock,
-                                    "persist_events.get_new_state_after_events",
+                                self._clock,
+                                "persist_events.get_new_state_after_events",
                             ):
-                                current_state = yield self._get_new_state_after_events(
+                                res = yield self._get_new_state_after_events(
                                     room_id,
                                     ev_ctx_rm,
                                     latest_event_ids,
                                     new_latest_event_ids,
                                 )
-
-                            if current_state is not None:
-                                current_state_for_room[room_id] = current_state
+                                current_state, delta_ids = res
+
+                            # If either are not None then there has been a change,
+                            # and we need to work out the delta (or use that
+                            # given)
+                            if delta_ids is not None:
+                                # If there is a delta we know that we've
+                                # only added or replaced state, never
+                                # removed keys entirely.
+                                state_delta_for_room[room_id] = ([], delta_ids)
+                            elif current_state is not None:
                                 with Measure(
-                                        self._clock,
-                                        "persist_events.calculate_state_delta",
+                                    self._clock,
+                                    "persist_events.calculate_state_delta",
                                 ):
                                     delta = yield self._calculate_state_delta(
                                         room_id, current_state,
                                     )
-                                    state_delta_for_room[room_id] = delta
+                                state_delta_for_room[room_id] = delta
+
+                            # If we have the current_state then lets prefill
+                            # the cache with it.
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -539,9 +553,10 @@ class EventsStore(EventsWorkerStore):
                 the new forward extremities for the room.
 
         Returns:
-            Deferred[dict[(str,str), str]|None]:
-                None if there are no changes to the room state, or
-                a dict of (type, state_key) -> event_id].
+            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+            Returns a tuple of two state maps, the first being the full new current
+            state and the second being the delta to the existing current state.
+            If both are None then there has been no change.
         """
 
         if not new_latest_event_ids:
@@ -549,6 +564,9 @@ class EventsStore(EventsWorkerStore):
 
         # map from state_group to ((type, key) -> event_id) state map
         state_groups_map = {}
+
+        state_group_deltas = {}
+
         for ev, ctx in events_context:
             if ctx.state_group is None:
                 # I don't think this can happen, but let's double-check
@@ -567,6 +585,9 @@ class EventsStore(EventsWorkerStore):
             if current_state_ids is not None:
                 state_groups_map[ctx.state_group] = current_state_ids
 
+            if ctx.prev_group:
+                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
         # We need to map the event_ids to their state groups. First, let's
         # check if the event is one we're persisting, in which case we can
         # pull the state group from its context.
@@ -608,7 +629,7 @@ class EventsStore(EventsWorkerStore):
         # If they old and new groups are the same then we don't need to do
         # anything.
         if old_state_groups == new_state_groups:
-            return
+            defer.returnValue((None, None))
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
@@ -620,7 +641,17 @@ class EventsStore(EventsWorkerStore):
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue(state_groups_map[new_state_groups.pop()])
+            new_state_group = new_state_groups.pop()
+
+            delta_ids = None
+            if len(old_state_groups) == 1:
+                old_state_group = old_state_groups.pop()
+
+                delta_ids = state_group_deltas.get(
+                    (old_state_group, new_state_group,), None
+                )
+
+            defer.returnValue((state_groups_map[new_state_group], delta_ids))
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
@@ -639,7 +670,7 @@ class EventsStore(EventsWorkerStore):
             room_id, state_groups, events_map, get_events
         )
 
-        defer.returnValue(res.state)
+        defer.returnValue((res.state, None))
 
     @defer.inlineCallbacks
     def _calculate_state_delta(self, room_id, current_state):