summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-03-28 16:19:26 +0100
committerErik Johnston <erik@matrix.org>2018-03-28 16:19:26 +0100
commitcfc2169b31d0771a808c52eeaf5e09f4222e50fc (patch)
tree3acefb549c0ad9ec316bfb1ddc45f59abc7884b4
parentMeasure time it takes to calculate state group ID (diff)
downloadsynapse-cfc2169b31d0771a808c52eeaf5e09f4222e50fc.tar.xz
WIP fast path state
-rw-r--r--synapse/event_auth.py29
-rw-r--r--synapse/state.py117
2 files changed, 134 insertions, 12 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index cd5627e36a..cf48f692fb 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -676,3 +676,32 @@ def auth_types_for_event(event):
                 auth_types.append(key)
 
     return auth_types
+
+
+def filter_dependent_state(keys, state):
+    if (EventTypes.Create, "") in keys:
+        return state
+
+    if (EventTypes.PowerLevels, "") in keys:
+        return state
+
+    def _filter_state(entry):
+        etype, state_key, sender = entry
+
+        if (etype, state_key) in keys:
+            return True
+
+        if (EventTypes.Member, sender) in keys:
+            return True
+
+        if etype == EventTypes.Member:
+            if (EventTypes.JoinRules, "") in keys:
+                return True
+
+            for t, _ in keys:
+                if t == EventTypes.ThirdPartyInvite:
+                    return True
+
+        return False
+
+    return filter(_filter_state, state)
diff --git a/synapse/state.py b/synapse/state.py
index 26093c8434..2ffdb0b01e 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -55,9 +55,9 @@ def _gen_state_id():
 
 
 class _StateCacheEntry(object):
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids", "conflicted_state"]
 
-    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+    def __init__(self, state, state_group, prev_group=None, delta_ids=None, conflicted_state=None):
         # dict[(str, str), str] map  from (type, state_key) to event_id
         self.state = frozendict(state)
 
@@ -80,6 +80,8 @@ class _StateCacheEntry(object):
         else:
             self.state_id = _gen_state_id()
 
+        self.conflicted_state = conflicted_state
+
     def __len__(self):
         return len(self.state)
 
@@ -375,7 +377,7 @@ class StateHandler(object):
         }
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+            new_state, _, _ = resolve_events_with_state_map(state_set_ids, state_map)
 
         new_state = {
             key: state_map[ev_id] for key, ev_id in new_state.iteritems()
@@ -462,18 +464,17 @@ class StateResolutionHandler(object):
                 for key, e_id in st.iteritems():
                     state.setdefault(key, set()).add(e_id)
 
-            # build a map from state key to the event_ids which set that state,
-            # including only those where there are state keys in conflict.
-            conflicted_state = {
-                k: list(v)
-                for k, v in state.iteritems()
-                if len(v) > 1
-            }
+            # Check that there is a conflict between the state groups
+            conflicted_state = False
+            for values in state.itervalues():
+                if len(values):
+                    conflicted_state = True
+                    break
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events_with_factory(
+                    new_state, _, conflicted_state_keys = yield resolve_events_with_factory(
                         state_groups_ids.values(),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
@@ -482,6 +483,7 @@ class StateResolutionHandler(object):
                 new_state = {
                     key: e_ids.pop() for key, e_ids in state.iteritems()
                 }
+                conflicted_state_keys = []
 
             with Measure(self.clock, "state.create_group_ids"):
                 # if the new state matches any of the input state groups, we can
@@ -517,6 +519,7 @@ class StateResolutionHandler(object):
                 state_group=state_group,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
+                conflicted_state=conflicted_state_keys,
             )
 
             if self._state_cache is not None:
@@ -524,6 +527,82 @@ class StateResolutionHandler(object):
 
             defer.returnValue(cache)
 
+    @defer.inlineCallbacks
+    def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
+                            event_map, state_map_factory):
+        deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
+
+        to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
+
+        new_groups = list(unchanged_state_groups)
+        new_groups.extend(g for _, g, _ in changed_groups)
+
+        types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
+        state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
+
+        state_sets = {
+            sg: {
+                key: cs[key]
+                for key in types
+            }
+            for sg, cs in state_sets.iteritems()
+        }
+
+        logger.info("Recalculating: %s", to_recalculate)
+
+        group_names = frozenset(new_groups)
+        with (yield self.resolve_linearizer.queue(group_names)):
+            cache = None
+            if self._state_cache is not None:
+                cache = self._state_cache.get(group_names, None)
+
+        if cache:
+            new_state = {
+                (etype, state_key): cache.state[(etype, state_key)]
+                for etype, state_key, _ in to_recalculate
+            }
+            unconflicted_state, _ = _seperate(
+                state_sets.values(),
+            )
+        else:
+            needed_events = set(
+                event_id
+                for state in state_sets.itervalues()
+                for event_id in state.itervalues()
+            )
+            if event_map is not None:
+                needed_events -= set(event_map.iterkeys())
+
+            # logger.info("state_sets: %s", state_sets)
+            logger.info("Asking for %d conflicted events", len(needed_events))
+
+            # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+            # the state events which are in conflict (and those in event_map)
+            state_map = yield state_map_factory(needed_events)
+            if event_map is not None:
+                state_map.update(event_map)
+
+            _, cs = _seperate(state_sets.values())
+
+            needed_state = _get_auth_event_keys(cs, state_map)
+            sg = new_groups[0]
+            res = yield store._get_state_for_groups((sg,), types=needed_state)
+            state = res[sg]
+            for s in state_sets.itervalues():
+                s.update(state)
+
+            logger.info("Added state: %s", len(needed_state))
+
+            new_state, unconflicted_state, _ = yield resolve_events_with_factory(
+                state_sets.values(),
+                event_map=state_map,
+                state_map_factory=state_map_factory,
+            )
+
+        conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
+
+        defer.returnValue((new_state, conflicted_state))
+
 
 def _ordered_events(events):
     def key_func(e):
@@ -677,6 +756,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
     ))
 
 
+def _get_auth_event_keys(conflicted_state, state_map):
+    auth_events = set()
+    for event_ids in conflicted_state.itervalues():
+        for event_id in event_ids:
+            if event_id in state_map:
+                auth_events.update(event_auth.auth_types_for_event(state_map[event_id]))
+    return auth_events
+
+
 def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
     auth_events = {}
     for event_ids in conflicted_state.itervalues():
@@ -719,7 +807,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
     for key, event in resolved_state.iteritems():
         new_state[key] = event.event_id
 
-    return new_state
+    return new_state, unconflicted_state_ids, [
+        (key[0], key[1], event.sender)
+        for key, evs in conflicted_state.iteritems()
+        if len(evs) > 1
+        for event in evs
+    ]
 
 
 def _resolve_state_events(conflicted_state, auth_events):