summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-03-28 16:57:55 +0100
committerErik Johnston <erik@matrix.org>2018-03-28 16:57:55 +0100
commitcdbf76e88ac8eac97f34d7b26891838928d19bdf (patch)
treed3bc535680e72ad0325f290516d207eaaa28d593
parentWIP fast path state (diff)
downloadsynapse-cdbf76e88ac8eac97f34d7b26891838928d19bdf.tar.xz
-rw-r--r--synapse/state.py46
1 files changed, 34 insertions, 12 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 2ffdb0b01e..1bf4a0df6f 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -531,24 +531,33 @@ class StateResolutionHandler(object):
     def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
                             event_map, state_map_factory):
         deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
+        conflicted_state_keys = set((etype, state_key) for etype, state_key, _ in conflicted_state)
+        conflicted_state_keys.update(deltas)
 
-        to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
+        if not deltas:
+            logger.info("No deltas")
+            defer.returnValue({}, conflicted_state)
+
+        to_recalculate = set(event_auth.filter_dependent_state(deltas, conflicted_state))
+        to_recalculate = set((etype, state_key) for etype, state_key, _ in to_recalculate)
+        to_recalculate.update(deltas)
 
         new_groups = list(unchanged_state_groups)
         new_groups.extend(g for _, g, _ in changed_groups)
 
-        types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
-        state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
+        state_sets = yield store._get_state_groups_from_groups(new_groups, types=to_recalculate)
 
         state_sets = {
             sg: {
                 key: cs[key]
-                for key in types
+                for key in to_recalculate
+                if key in cs
             }
             for sg, cs in state_sets.iteritems()
         }
 
         logger.info("Recalculating: %s", to_recalculate)
+        logger.info("State Sets: %s", state_sets)
 
         group_names = frozenset(new_groups)
         with (yield self.resolve_linearizer.queue(group_names)):
@@ -557,9 +566,10 @@ class StateResolutionHandler(object):
                 cache = self._state_cache.get(group_names, None)
 
         if cache:
+            logger.info("Using cache")
             new_state = {
-                (etype, state_key): cache.state[(etype, state_key)]
-                for etype, state_key, _ in to_recalculate
+                key: cache.state[key]
+                for key in to_recalculate
             }
             unconflicted_state, _ = _seperate(
                 state_sets.values(),
@@ -585,13 +595,23 @@ class StateResolutionHandler(object):
             _, cs = _seperate(state_sets.values())
 
             needed_state = _get_auth_event_keys(cs, state_map)
-            sg = new_groups[0]
-            res = yield store._get_state_for_groups((sg,), types=needed_state)
-            state = res[sg]
-            for s in state_sets.itervalues():
-                s.update(state)
 
-            logger.info("Added state: %s", len(needed_state))
+            if needed_state - conflicted_state_keys:
+                res = yield store._get_state_for_groups((new_groups[0],), types=(needed_state - conflicted_state_keys))
+                state = res[new_groups[0]]
+                for s in state_sets.itervalues():
+                    s.update(state)
+
+                logger.info("Added unconflicted state: %s", state)
+
+            needed_conflicted_state = needed_state & conflicted_state_keys
+            if needed_conflicted_state:
+                for sg, s in state_sets.iteritems():
+                    res = yield store._get_state_for_groups((sg,), types=needed_conflicted_state)
+                    state = res[sg]
+                    s.update(state)
+
+                logger.info("Added conflicted state: %s", needed_conflicted_state)
 
             new_state, unconflicted_state, _ = yield resolve_events_with_factory(
                 state_sets.values(),
@@ -601,6 +621,8 @@ class StateResolutionHandler(object):
 
         conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
 
+        logger.info("Returning: %s", new_state)
+
         defer.returnValue((new_state, conflicted_state))