summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py246
1 files changed, 147 insertions, 99 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 15a593d41c..033f55d967 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,7 +18,7 @@ import hashlib
 import logging
 from collections import namedtuple
 
-from six import iteritems, itervalues
+from six import iteritems, iterkeys, itervalues
 
 from frozendict import frozendict
 
@@ -203,25 +203,27 @@ class StateHandler(object):
             # If this is an outlier, then we know it shouldn't have any current
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
-            context = EventContext()
             if old_state:
-                context.prev_state_ids = {
+                prev_state_ids = {
                     (s.type, s.state_key): s.event_id for s in old_state
                 }
                 if event.is_state():
-                    context.current_state_ids = dict(context.prev_state_ids)
+                    current_state_ids = dict(prev_state_ids)
                     key = (event.type, event.state_key)
-                    context.current_state_ids[key] = event.event_id
+                    current_state_ids[key] = event.event_id
                 else:
-                    context.current_state_ids = context.prev_state_ids
+                    current_state_ids = prev_state_ids
             else:
-                context.current_state_ids = {}
-                context.prev_state_ids = {}
-            context.prev_state_events = []
+                current_state_ids = {}
+                prev_state_ids = {}
 
             # We don't store state for outliers, so we don't generate a state
-            # froup for it.
-            context.state_group = None
+            # group for it.
+            context = EventContext.with_state(
+                state_group=None,
+                current_state_ids=current_state_ids,
+                prev_state_ids=prev_state_ids,
+            )
 
             defer.returnValue(context)
 
@@ -230,31 +232,35 @@ class StateHandler(object):
             # Let's just correctly fill out the context and create a
             # new state group for it.
 
-            context = EventContext()
-            context.prev_state_ids = {
+            prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.prev_state_ids:
-                    replaces = context.prev_state_ids[key]
+                if key in prev_state_ids:
+                    replaces = prev_state_ids[key]
                     if replaces != event.event_id:  # Paranoia check
                         event.unsigned["replaces_state"] = replaces
-                context.current_state_ids = dict(context.prev_state_ids)
-                context.current_state_ids[key] = event.event_id
+                current_state_ids = dict(prev_state_ids)
+                current_state_ids[key] = event.event_id
             else:
-                context.current_state_ids = context.prev_state_ids
+                current_state_ids = prev_state_ids
 
-            context.state_group = yield self.store.store_state_group(
+            state_group = yield self.store.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=None,
                 delta_ids=None,
-                current_state_ids=context.current_state_ids,
+                current_state_ids=current_state_ids,
+            )
+
+            context = EventContext.with_state(
+                state_group=state_group,
+                current_state_ids=current_state_ids,
+                prev_state_ids=prev_state_ids,
             )
 
-            context.prev_state_events = []
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
@@ -262,47 +268,47 @@ class StateHandler(object):
             event.room_id, [e for e, _ in event.prev_events],
         )
 
-        curr_state = entry.state
+        prev_state_ids = entry.state
+        prev_group = None
+        delta_ids = None
 
-        context = EventContext()
-        context.prev_state_ids = curr_state
         if event.is_state():
             # If this is a state event then we need to create a new state
             # group for the state after this event.
 
             key = (event.type, event.state_key)
-            if key in context.prev_state_ids:
-                replaces = context.prev_state_ids[key]
+            if key in prev_state_ids:
+                replaces = prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
 
-            context.current_state_ids = dict(context.prev_state_ids)
-            context.current_state_ids[key] = event.event_id
+            current_state_ids = dict(prev_state_ids)
+            current_state_ids[key] = event.event_id
 
             if entry.state_group:
                 # If the state at the event has a state group assigned then
                 # we can use that as the prev group
-                context.prev_group = entry.state_group
-                context.delta_ids = {
+                prev_group = entry.state_group
+                delta_ids = {
                     key: event.event_id
                 }
             elif entry.prev_group:
                 # If the state at the event only has a prev group, then we can
                 # use that as a prev group too.
-                context.prev_group = entry.prev_group
-                context.delta_ids = dict(entry.delta_ids)
-                context.delta_ids[key] = event.event_id
+                prev_group = entry.prev_group
+                delta_ids = dict(entry.delta_ids)
+                delta_ids[key] = event.event_id
 
-            context.state_group = yield self.store.store_state_group(
+            state_group = yield self.store.store_state_group(
                 event.event_id,
                 event.room_id,
-                prev_group=context.prev_group,
-                delta_ids=context.delta_ids,
-                current_state_ids=context.current_state_ids,
+                prev_group=prev_group,
+                delta_ids=delta_ids,
+                current_state_ids=current_state_ids,
             )
         else:
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
+            current_state_ids = prev_state_ids
+            prev_group = entry.prev_group
+            delta_ids = entry.delta_ids
 
             if entry.state_group is None:
                 entry.state_group = yield self.store.store_state_group(
@@ -310,13 +316,20 @@ class StateHandler(object):
                     event.room_id,
                     prev_group=entry.prev_group,
                     delta_ids=entry.delta_ids,
-                    current_state_ids=context.current_state_ids,
+                    current_state_ids=current_state_ids,
                 )
                 entry.state_id = entry.state_group
 
-            context.state_group = entry.state_group
+            state_group = entry.state_group
+
+        context = EventContext.with_state(
+            state_group=state_group,
+            current_state_ids=current_state_ids,
+            prev_state_ids=prev_state_ids,
+            prev_group=prev_group,
+            delta_ids=delta_ids,
+        )
 
-        context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
@@ -458,69 +471,39 @@ class StateResolutionHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
-            # build a map from state key to the event_ids which set that state.
-            # dict[(str, str), set[str])
-            state = {}
+            # start by assuming we won't have any conflicted state, and build up the new
+            # state map by iterating through the state groups. If we discover a conflict,
+            # we give up and instead use `resolve_events_with_factory`.
+            #
+            # XXX: is this actually worthwhile, or should we just let
+            # resolve_events_with_factory do it?
+            new_state = {}
+            conflicted_state = False
             for st in itervalues(state_groups_ids):
                 for key, e_id in iteritems(st):
-                    state.setdefault(key, set()).add(e_id)
-
-            # build a map from state key to the event_ids which set that state,
-            # including only those where there are state keys in conflict.
-            conflicted_state = {
-                k: list(v)
-                for k, v in iteritems(state)
-                if len(v) > 1
-            }
+                    if key in new_state:
+                        conflicted_state = True
+                        break
+                    new_state[key] = e_id
+                if conflicted_state:
+                    break
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
-                        list(state_groups_ids.values()),
+                        list(itervalues(state_groups_ids)),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
-            else:
-                new_state = {
-                    key: e_ids.pop() for key, e_ids in iteritems(state)
-                }
 
-            with Measure(self.clock, "state.create_group_ids"):
-                # if the new state matches any of the input state groups, we can
-                # use that state group again. Otherwise we will generate a state_id
-                # which will be used as a cache key for future resolutions, but
-                # not get persisted.
-                state_group = None
-                new_state_event_ids = frozenset(itervalues(new_state))
-                for sg, events in iteritems(state_groups_ids):
-                    if new_state_event_ids == frozenset(e_id for e_id in events):
-                        state_group = sg
-                        break
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
 
-                # TODO: We want to create a state group for this set of events, to
-                # increase cache hits, but we need to make sure that it doesn't
-                # end up as a prev_group without being added to the database
-
-                prev_group = None
-                delta_ids = None
-                for old_group, old_ids in iteritems(state_groups_ids):
-                    if not set(new_state) - set(old_ids):
-                        n_delta_ids = {
-                            k: v
-                            for k, v in iteritems(new_state)
-                            if old_ids.get(k) != v
-                        }
-                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                            prev_group = old_group
-                            delta_ids = n_delta_ids
-
-            cache = _StateCacheEntry(
-                state=new_state,
-                state_group=state_group,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            )
+            with Measure(self.clock, "state.create_group_ids"):
+                cache = _make_state_cache_entry(new_state, state_groups_ids)
 
             if self._state_cache is not None:
                 self._state_cache[group_names] = cache
@@ -528,6 +511,70 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
+def _make_state_cache_entry(
+    new_state,
+    state_groups_ids,
+):
+    """Given a resolved state, and a set of input state groups, pick one to base
+    a new state group on (if any), and return an appropriately-constructed
+    _StateCacheEntry.
+
+    Args:
+        new_state (dict[(str, str), str]): resolved state map (mapping from
+           (type, state_key) to event_id)
+
+        state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+    Returns:
+        _StateCacheEntry
+    """
+    # if the new state matches any of the input state groups, we can
+    # use that state group again. Otherwise we will generate a state_id
+    # which will be used as a cache key for future resolutions, but
+    # not get persisted.
+
+    # first look for exact matches
+    new_state_event_ids = set(itervalues(new_state))
+    for sg, state in iteritems(state_groups_ids):
+        if len(new_state_event_ids) != len(state):
+            continue
+
+        old_state_event_ids = set(itervalues(state))
+        if new_state_event_ids == old_state_event_ids:
+            # got an exact match.
+            return _StateCacheEntry(
+                state=new_state,
+                state_group=sg,
+            )
+
+    # TODO: We want to create a state group for this set of events, to
+    # increase cache hits, but we need to make sure that it doesn't
+    # end up as a prev_group without being added to the database
+
+    # failing that, look for the closest match.
+    prev_group = None
+    delta_ids = None
+
+    for old_group, old_state in iteritems(state_groups_ids):
+        n_delta_ids = {
+            k: v
+            for k, v in iteritems(new_state)
+            if old_state.get(k) != v
+        }
+        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+            prev_group = old_group
+            delta_ids = n_delta_ids
+
+    return _StateCacheEntry(
+        state=new_state,
+        state_group=None,
+        prev_group=prev_group,
+        delta_ids=delta_ids,
+    )
+
+
 def _ordered_events(events):
     def key_func(e):
         return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
@@ -569,7 +616,7 @@ def _seperate(state_sets):
     with them in different state sets.
 
     Args:
-        state_sets(list[dict[(str, str), str]]):
+        state_sets(iterable[dict[(str, str), str]]):
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
@@ -583,10 +630,11 @@ def _seperate(state_sets):
             conflicted_state is a dict mapping (type, state_key) to a set of
             event ids for conflicted state keys.
     """
-    unconflicted_state = dict(state_sets[0])
+    state_set_iterator = iter(state_sets)
+    unconflicted_state = dict(next(state_set_iterator))
     conflicted_state = {}
 
-    for state_set in state_sets[1:]:
+    for state_set in state_set_iterator:
         for key, value in iteritems(state_set):
             # Check if there is an unconflicted entry for the state key.
             unconflicted_value = unconflicted_state.get(key)
@@ -647,7 +695,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
         for event_id in event_ids
     )
     if event_map is not None:
-        needed_events -= set(event_map.iterkeys())
+        needed_events -= set(iterkeys(event_map))
 
     logger.info("Asking for %d conflicted events", len(needed_events))
 
@@ -668,7 +716,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
     new_needed_events = set(itervalues(auth_events))
     new_needed_events -= needed_events
     if event_map is not None:
-        new_needed_events -= set(event_map.iterkeys())
+        new_needed_events -= set(iterkeys(event_map))
 
     logger.info("Asking for %d auth events", len(new_needed_events))