summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
authorMatthew Hodgson <matthew@matrix.org>2018-05-29 00:25:22 +0100
committerMatthew Hodgson <matthew@matrix.org>2018-05-29 00:25:22 +0100
commit7a6df013cc8a128278d2ce7e5eb569e0b424f9b0 (patch)
tree5de624a65953eb96ab67274462d850a88c0cce3c /synapse/state.py
parentmake lazy_load_members configurable in filters (diff)
parentMerge pull request #3256 from matrix-org/3218-official-prom (diff)
downloadsynapse-7a6df013cc8a128278d2ce7e5eb569e0b424f9b0.tar.xz
merge develop
Diffstat (limited to '')
-rw-r--r--synapse/state.py95
1 files changed, 49 insertions, 46 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 932f602508..b8c27c6815 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -32,6 +32,8 @@ from frozendict import frozendict
 import logging
 import hashlib
 
+from six import iteritems, itervalues
+
 logger = logging.getLogger(__name__)
 
 
@@ -132,7 +134,7 @@ class StateHandler(object):
 
         state_map = yield self.store.get_events(state.values(), get_prev_content=False)
         state = {
-            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+            key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
         }
 
         defer.returnValue(state)
@@ -338,7 +340,7 @@ class StateHandler(object):
         )
 
         if len(state_groups_ids) == 1:
-            name, state_list = state_groups_ids.items().pop()
+            name, state_list = list(state_groups_ids.items()).pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
 
@@ -378,7 +380,7 @@ class StateHandler(object):
             new_state = resolve_events_with_state_map(state_set_ids, state_map)
 
         new_state = {
-            key: state_map[ev_id] for key, ev_id in new_state.items()
+            key: state_map[ev_id] for key, ev_id in iteritems(new_state)
         }
 
         return new_state
@@ -458,15 +460,15 @@ class StateResolutionHandler(object):
             # build a map from state key to the event_ids which set that state.
             # dict[(str, str), set[str])
             state = {}
-            for st in state_groups_ids.values():
-                for key, e_id in st.items():
+            for st in itervalues(state_groups_ids):
+                for key, e_id in iteritems(st):
                     state.setdefault(key, set()).add(e_id)
 
             # build a map from state key to the event_ids which set that state,
             # including only those where there are state keys in conflict.
             conflicted_state = {
                 k: list(v)
-                for k, v in state.items()
+                for k, v in iteritems(state)
                 if len(v) > 1
             }
 
@@ -474,42 +476,43 @@ class StateResolutionHandler(object):
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
-                        state_groups_ids.values(),
+                        list(state_groups_ids.values()),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
             else:
                 new_state = {
-                    key: e_ids.pop() for key, e_ids in state.items()
+                    key: e_ids.pop() for key, e_ids in iteritems(state)
                 }
 
-            # if the new state matches any of the input state groups, we can
-            # use that state group again. Otherwise we will generate a state_id
-            # which will be used as a cache key for future resolutions, but
-            # not get persisted.
-            state_group = None
-            new_state_event_ids = frozenset(new_state.values())
-            for sg, events in state_groups_ids.items():
-                if new_state_event_ids == frozenset(e_id for e_id in events):
-                    state_group = sg
-                    break
-
-            # TODO: We want to create a state group for this set of events, to
-            # increase cache hits, but we need to make sure that it doesn't
-            # end up as a prev_group without being added to the database
-
-            prev_group = None
-            delta_ids = None
-            for old_group, old_ids in state_groups_ids.iteritems():
-                if not set(new_state) - set(old_ids):
-                    n_delta_ids = {
-                        k: v
-                        for k, v in new_state.iteritems()
-                        if old_ids.get(k) != v
-                    }
-                    if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                        prev_group = old_group
-                        delta_ids = n_delta_ids
+            with Measure(self.clock, "state.create_group_ids"):
+                # if the new state matches any of the input state groups, we can
+                # use that state group again. Otherwise we will generate a state_id
+                # which will be used as a cache key for future resolutions, but
+                # not get persisted.
+                state_group = None
+                new_state_event_ids = frozenset(itervalues(new_state))
+                for sg, events in iteritems(state_groups_ids):
+                    if new_state_event_ids == frozenset(e_id for e_id in events):
+                        state_group = sg
+                        break
+
+                # TODO: We want to create a state group for this set of events, to
+                # increase cache hits, but we need to make sure that it doesn't
+                # end up as a prev_group without being added to the database
+
+                prev_group = None
+                delta_ids = None
+                for old_group, old_ids in iteritems(state_groups_ids):
+                    if not set(new_state) - set(old_ids):
+                        n_delta_ids = {
+                            k: v
+                            for k, v in iteritems(new_state)
+                            if old_ids.get(k) != v
+                        }
+                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                            prev_group = old_group
+                            delta_ids = n_delta_ids
 
             cache = _StateCacheEntry(
                 state=new_state,
@@ -526,7 +529,7 @@ class StateResolutionHandler(object):
 
 def _ordered_events(events):
     def key_func(e):
-        return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+        return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
 
     return sorted(events, key=key_func)
 
@@ -583,7 +586,7 @@ def _seperate(state_sets):
     conflicted_state = {}
 
     for state_set in state_sets[1:]:
-        for key, value in state_set.iteritems():
+        for key, value in iteritems(state_set):
             # Check if there is an unconflicted entry for the state key.
             unconflicted_value = unconflicted_state.get(key)
             if unconflicted_value is None:
@@ -639,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
 
     needed_events = set(
         event_id
-        for event_ids in conflicted_state.itervalues()
+        for event_ids in itervalues(conflicted_state)
         for event_id in event_ids
     )
     if event_map is not None:
@@ -661,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
         unconflicted_state, conflicted_state, state_map
     )
 
-    new_needed_events = set(auth_events.itervalues())
+    new_needed_events = set(itervalues(auth_events))
     new_needed_events -= needed_events
     if event_map is not None:
         new_needed_events -= set(event_map.iterkeys())
@@ -678,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
 
 def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
     auth_events = {}
-    for event_ids in conflicted_state.itervalues():
+    for event_ids in itervalues(conflicted_state):
         for event_id in event_ids:
             if event_id in state_map:
                 keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -693,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
 def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
                         state_map):
     conflicted_state = {}
-    for key, event_ids in conflicted_state_ds.iteritems():
+    for key, event_ids in iteritems(conflicted_state_ds):
         events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
         if len(events) > 1:
             conflicted_state[key] = events
@@ -702,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
 
     auth_events = {
         key: state_map[ev_id]
-        for key, ev_id in auth_event_ids.items()
+        for key, ev_id in iteritems(auth_event_ids)
         if ev_id in state_map
     }
 
@@ -715,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
         raise
 
     new_state = unconflicted_state_ids
-    for key, event in resolved_state.iteritems():
+    for key, event in iteritems(resolved_state):
         new_state[key] = event.event_id
 
     return new_state
@@ -740,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in iteritems(conflicted_state):
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
             resolved_state[key] = _resolve_auth_events(
@@ -750,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in iteritems(conflicted_state):
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
             resolved_state[key] = _resolve_auth_events(
@@ -760,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in iteritems(conflicted_state):
         if key not in resolved_state:
             logger.debug("Resolving conflicted state %r:%r", key, events)
             resolved_state[key] = _resolve_normal_events(