summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/state.py143
1 files changed, 89 insertions, 54 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 32125c95df..033f55d967 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -471,69 +471,39 @@ class StateResolutionHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
-            # build a map from state key to the event_ids which set that state.
-            # dict[(str, str), set[str])
-            state = {}
+            # start by assuming we won't have any conflicted state, and build up the new
+            # state map by iterating through the state groups. If we discover a conflict,
+            # we give up and instead use `resolve_events_with_factory`.
+            #
+            # XXX: is this actually worthwhile, or should we just let
+            # resolve_events_with_factory do it?
+            new_state = {}
+            conflicted_state = False
             for st in itervalues(state_groups_ids):
                 for key, e_id in iteritems(st):
-                    state.setdefault(key, set()).add(e_id)
-
-            # build a map from state key to the event_ids which set that state,
-            # including only those where there are state keys in conflict.
-            conflicted_state = {
-                k: list(v)
-                for k, v in iteritems(state)
-                if len(v) > 1
-            }
+                    if key in new_state:
+                        conflicted_state = True
+                        break
+                    new_state[key] = e_id
+                if conflicted_state:
+                    break
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
-                        list(state_groups_ids.values()),
+                        list(itervalues(state_groups_ids)),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
-            else:
-                new_state = {
-                    key: e_ids.pop() for key, e_ids in iteritems(state)
-                }
 
-            with Measure(self.clock, "state.create_group_ids"):
-                # if the new state matches any of the input state groups, we can
-                # use that state group again. Otherwise we will generate a state_id
-                # which will be used as a cache key for future resolutions, but
-                # not get persisted.
-                state_group = None
-                new_state_event_ids = frozenset(itervalues(new_state))
-                for sg, events in iteritems(state_groups_ids):
-                    if new_state_event_ids == frozenset(e_id for e_id in events):
-                        state_group = sg
-                        break
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
 
-                # TODO: We want to create a state group for this set of events, to
-                # increase cache hits, but we need to make sure that it doesn't
-                # end up as a prev_group without being added to the database
-
-                prev_group = None
-                delta_ids = None
-                for old_group, old_ids in iteritems(state_groups_ids):
-                    if not set(new_state) - set(old_ids):
-                        n_delta_ids = {
-                            k: v
-                            for k, v in iteritems(new_state)
-                            if old_ids.get(k) != v
-                        }
-                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                            prev_group = old_group
-                            delta_ids = n_delta_ids
-
-            cache = _StateCacheEntry(
-                state=new_state,
-                state_group=state_group,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            )
+            with Measure(self.clock, "state.create_group_ids"):
+                cache = _make_state_cache_entry(new_state, state_groups_ids)
 
             if self._state_cache is not None:
                 self._state_cache[group_names] = cache
@@ -541,6 +511,70 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
+def _make_state_cache_entry(
+    new_state,
+    state_groups_ids,
+):
+    """Given a resolved state, and a set of input state groups, pick one to base
+    a new state group on (if any), and return an appropriately-constructed
+    _StateCacheEntry.
+
+    Args:
+        new_state (dict[(str, str), str]): resolved state map (mapping from
+           (type, state_key) to event_id)
+
+        state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+    Returns:
+        _StateCacheEntry
+    """
+    # if the new state matches any of the input state groups, we can
+    # use that state group again. Otherwise we will generate a state_id
+    # which will be used as a cache key for future resolutions, but
+    # not get persisted.
+
+    # first look for exact matches
+    new_state_event_ids = set(itervalues(new_state))
+    for sg, state in iteritems(state_groups_ids):
+        if len(new_state_event_ids) != len(state):
+            continue
+
+        old_state_event_ids = set(itervalues(state))
+        if new_state_event_ids == old_state_event_ids:
+            # got an exact match.
+            return _StateCacheEntry(
+                state=new_state,
+                state_group=sg,
+            )
+
+    # TODO: We want to create a state group for this set of events, to
+    # increase cache hits, but we need to make sure that it doesn't
+    # end up as a prev_group without being added to the database
+
+    # failing that, look for the closest match.
+    prev_group = None
+    delta_ids = None
+
+    for old_group, old_state in iteritems(state_groups_ids):
+        n_delta_ids = {
+            k: v
+            for k, v in iteritems(new_state)
+            if old_state.get(k) != v
+        }
+        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+            prev_group = old_group
+            delta_ids = n_delta_ids
+
+    return _StateCacheEntry(
+        state=new_state,
+        state_group=None,
+        prev_group=prev_group,
+        delta_ids=delta_ids,
+    )
+
+
 def _ordered_events(events):
     def key_func(e):
         return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
@@ -582,7 +616,7 @@ def _seperate(state_sets):
     with them in different state sets.
 
     Args:
-        state_sets(list[dict[(str, str), str]]):
+        state_sets(iterable[dict[(str, str), str]]):
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
@@ -596,10 +630,11 @@ def _seperate(state_sets):
             conflicted_state is a dict mapping (type, state_key) to a set of
             event ids for conflicted state keys.
     """
-    unconflicted_state = dict(state_sets[0])
+    state_set_iterator = iter(state_sets)
+    unconflicted_state = dict(next(state_set_iterator))
     conflicted_state = {}
 
-    for state_set in state_sets[1:]:
+    for state_set in state_set_iterator:
         for key, value in iteritems(state_set):
             # Check if there is an unconflicted entry for the state key.
             unconflicted_value = unconflicted_state.get(key)