summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/synapse/state.py b/synapse/state.py
index b31bbcdbd2..cd428e83cd 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -54,12 +54,15 @@ def _gen_state_id():
 
 
 class _StateCacheEntry(object):
-    __slots__ = ["state", "state_group", "state_id"]
+    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
-    def __init__(self, state, state_group):
+    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
         self.state = state
         self.state_group = state_group
 
+        self.prev_group = prev_group
+        self.delta_ids = delta_ids
+
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
         # state group, but on worker instances we can't generate a new state
@@ -243,11 +246,20 @@ class StateHandler(object):
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
+
             context.current_state_ids = dict(context.prev_state_ids)
             context.current_state_ids[key] = event.event_id
+
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+            if context.delta_ids is not None:
+                context.delta_ids[key] = event.event_id
         else:
             context.current_state_ids = context.prev_state_ids
 
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
         context.prev_state_events = []
         defer.returnValue(context)
 
@@ -281,6 +293,8 @@ class StateHandler(object):
             defer.returnValue(_StateCacheEntry(
                 state=state_list,
                 state_group=name,
+                prev_group=name,
+                delta_ids={},
             ))
 
         if self._state_cache is not None:
@@ -330,6 +344,7 @@ class StateHandler(object):
             if new_state_event_ids == frozenset(e_id for e_id in events):
                 state_group = sg
                 break
+
         if state_group is None:
             # Worker instances don't have access to this method, but we want
             # to set the state_group on the main instance to increase cache
@@ -337,9 +352,24 @@ class StateHandler(object):
             if hasattr(self.store, "get_next_state_group"):
                 state_group = self.store.get_next_state_group()
 
+        prev_group = None
+        delta_ids = None
+        for old_group, old_ids in state_groups_ids.items():
+            if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+                n_delta_ids = {
+                    k: v
+                    for k, v in new_state.items()
+                    if old_ids.get(k) != v
+                }
+                if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                    prev_group = old_group
+                    delta_ids = n_delta_ids
+
         cache = _StateCacheEntry(
             state=new_state,
             state_group=state_group,
+            prev_group=prev_group,
+            delta_ids=delta_ids,
         )
 
         if self._state_cache is not None: