summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2018-07-23 19:00:16 +0100
committerRichard van der Hoff <richard@matrix.org>2018-07-23 19:10:11 +0100
commit5c705f70c9489427a7985ea10ec60552965b9a1c (patch)
tree37c8277f6cb7fc7428bb1088da660f91cde0df08 /synapse/state.py
parentMerge pull request #3584 from matrix-org/erikj/use_cached (diff)
downloadsynapse-5c705f70c9489427a7985ea10ec60552965b9a1c.tar.xz
Fixes and optimisations for resolve_state_groups
First of all, fix the logic which looks for identical input state groups so
that we actually use them. This turned out to be most easily done by factoring
the relevant code out to a separate function so that we could do an early
return.

Secondly, avoid building the whole `conflicted_state` dict (which was only ever
used as a boolean flag).

Thirdly, replace the construction of the `state` dict (which mapped from keys
to events that set them), with an optimistic construction of the resolution
result assuming there will be no conflicts. This should be no slower than
building the old `state` dict, and:
  - in the conflicted case, we'll short-cut it, saving part of the work
  - in the unconflicted case, it saves rebuilding the resolution from the
    `state` dict.

Finally, do a couple of s/values/itervalues/.
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py143
1 files changed, 89 insertions, 54 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 32125c95df..033f55d967 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -471,69 +471,39 @@ class StateResolutionHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
-            # build a map from state key to the event_ids which set that state.
-            # dict[(str, str), set[str])
-            state = {}
+            # start by assuming we won't have any conflicted state, and build up the new
+            # state map by iterating through the state groups. If we discover a conflict,
+            # we give up and instead use `resolve_events_with_factory`.
+            #
+            # XXX: is this actually worthwhile, or should we just let
+            # resolve_events_with_factory do it?
+            new_state = {}
+            conflicted_state = False
             for st in itervalues(state_groups_ids):
                 for key, e_id in iteritems(st):
-                    state.setdefault(key, set()).add(e_id)
-
-            # build a map from state key to the event_ids which set that state,
-            # including only those where there are state keys in conflict.
-            conflicted_state = {
-                k: list(v)
-                for k, v in iteritems(state)
-                if len(v) > 1
-            }
+                    if key in new_state:
+                        conflicted_state = True
+                        break
+                    new_state[key] = e_id
+                if conflicted_state:
+                    break
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
-                        list(state_groups_ids.values()),
+                        list(itervalues(state_groups_ids)),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
-            else:
-                new_state = {
-                    key: e_ids.pop() for key, e_ids in iteritems(state)
-                }
 
-            with Measure(self.clock, "state.create_group_ids"):
-                # if the new state matches any of the input state groups, we can
-                # use that state group again. Otherwise we will generate a state_id
-                # which will be used as a cache key for future resolutions, but
-                # not get persisted.
-                state_group = None
-                new_state_event_ids = frozenset(itervalues(new_state))
-                for sg, events in iteritems(state_groups_ids):
-                    if new_state_event_ids == frozenset(e_id for e_id in events):
-                        state_group = sg
-                        break
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
 
-                # TODO: We want to create a state group for this set of events, to
-                # increase cache hits, but we need to make sure that it doesn't
-                # end up as a prev_group without being added to the database
-
-                prev_group = None
-                delta_ids = None
-                for old_group, old_ids in iteritems(state_groups_ids):
-                    if not set(new_state) - set(old_ids):
-                        n_delta_ids = {
-                            k: v
-                            for k, v in iteritems(new_state)
-                            if old_ids.get(k) != v
-                        }
-                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                            prev_group = old_group
-                            delta_ids = n_delta_ids
-
-            cache = _StateCacheEntry(
-                state=new_state,
-                state_group=state_group,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            )
+            with Measure(self.clock, "state.create_group_ids"):
+                cache = _make_state_cache_entry(new_state, state_groups_ids)
 
             if self._state_cache is not None:
                 self._state_cache[group_names] = cache
@@ -541,6 +511,70 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
+def _make_state_cache_entry(
+    new_state,
+    state_groups_ids,
+):
+    """Given a resolved state, and a set of input state groups, pick one to base
+    a new state group on (if any), and return an appropriately-constructed
+    _StateCacheEntry.
+
+    Args:
+        new_state (dict[(str, str), str]): resolved state map (mapping from
+           (type, state_key) to event_id)
+
+        state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+    Returns:
+        _StateCacheEntry
+    """
+    # if the new state matches any of the input state groups, we can
+    # use that state group again. Otherwise we will generate a state_id
+    # which will be used as a cache key for future resolutions, but
+    # not get persisted.
+
+    # first look for exact matches
+    new_state_event_ids = set(itervalues(new_state))
+    for sg, state in iteritems(state_groups_ids):
+        if len(new_state_event_ids) != len(state):
+            continue
+
+        old_state_event_ids = set(itervalues(state))
+        if new_state_event_ids == old_state_event_ids:
+            # got an exact match.
+            return _StateCacheEntry(
+                state=new_state,
+                state_group=sg,
+            )
+
+    # TODO: We want to create a state group for this set of events, to
+    # increase cache hits, but we need to make sure that it doesn't
+    # end up as a prev_group without being added to the database
+
+    # failing that, look for the closest match.
+    prev_group = None
+    delta_ids = None
+
+    for old_group, old_state in iteritems(state_groups_ids):
+        n_delta_ids = {
+            k: v
+            for k, v in iteritems(new_state)
+            if old_state.get(k) != v
+        }
+        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+            prev_group = old_group
+            delta_ids = n_delta_ids
+
+    return _StateCacheEntry(
+        state=new_state,
+        state_group=None,
+        prev_group=prev_group,
+        delta_ids=delta_ids,
+    )
+
+
 def _ordered_events(events):
     def key_func(e):
         return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
@@ -582,7 +616,7 @@ def _seperate(state_sets):
     with them in different state sets.
 
     Args:
-        state_sets(list[dict[(str, str), str]]):
+        state_sets(iterable[dict[(str, str), str]]):
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
@@ -596,10 +630,11 @@ def _seperate(state_sets):
             conflicted_state is a dict mapping (type, state_key) to a set of
             event ids for conflicted state keys.
     """
-    unconflicted_state = dict(state_sets[0])
+    state_set_iterator = iter(state_sets)
+    unconflicted_state = dict(next(state_set_iterator))
     conflicted_state = {}
 
-    for state_set in state_sets[1:]:
+    for state_set in state_set_iterator:
         for key, value in iteritems(state_set):
             # Check if there is an unconflicted entry for the state key.
             unconflicted_value = unconflicted_state.get(key)