summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-06-20 11:59:14 +0100
committerErik Johnston <erik@matrix.org>2019-06-20 11:59:14 +0100
commit45f28a9d2fc0466dcf2a05b0063b7caa3b7e12c3 (patch)
tree07bb21377c6611db89f64f948a2e27645662ff0e /synapse/state
parentAdd descriptions and remove redundant set(..) (diff)
parentRun Black. (#5482) (diff)
downloadsynapse-45f28a9d2fc0466dcf2a05b0063b7caa3b7e12c3.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/histogram_extremities
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py107
-rw-r--r--synapse/state/v1.py56
-rw-r--r--synapse/state/v2.py92
3 files changed, 104 insertions, 151 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b136b3054..1b454a56a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -107,8 +107,9 @@ class StateHandler(object):
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key="",
-                          latest_event_ids=None):
+    def get_current_state(
+        self, room_id, event_type=None, state_key="", latest_event_ids=None
+    ):
         """ Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
@@ -137,8 +138,9 @@ class StateHandler(object):
             defer.returnValue(event)
             return
 
-        state_map = yield self.store.get_events(list(state.values()),
-                                                get_prev_content=False)
+        state_map = yield self.store.get_events(
+            list(state.values()), get_prev_content=False
+        )
         state = {
             key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
         }
@@ -220,9 +222,7 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                prev_state_ids = {
-                    (s.type, s.state_key): s.event_id for s in old_state
-                }
+                prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
                 if event.is_state():
                     current_state_ids = dict(prev_state_ids)
                     key = (event.type, event.state_key)
@@ -248,9 +248,7 @@ class StateHandler(object):
             # Let's just correctly fill out the context and create a
             # new state group for it.
 
-            prev_state_ids = {
-                (s.type, s.state_key): s.event_id for s in old_state
-            }
+            prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -282,7 +280,7 @@ class StateHandler(object):
         logger.debug("calling resolve_state_groups from compute_event_context")
 
         entry = yield self.resolve_state_groups_for_events(
-            event.room_id, event.prev_event_ids(),
+            event.room_id, event.prev_event_ids()
         )
 
         prev_state_ids = entry.state
@@ -305,9 +303,7 @@ class StateHandler(object):
                 # If the state at the event has a state group assigned then
                 # we can use that as the prev group
                 prev_group = entry.state_group
-                delta_ids = {
-                    key: event.event_id
-                }
+                delta_ids = {key: event.event_id}
             elif entry.prev_group:
                 # If the state at the event only has a prev group, then we can
                 # use that as a prev group too.
@@ -369,31 +365,31 @@ class StateHandler(object):
         # map from state group id to the state in that state group (where
         # 'state' is a map from state key to event id)
         # dict[int, dict[(str, str), str]]
-        state_groups_ids = yield self.store.get_state_groups_ids(
-            room_id, event_ids
-        )
+        state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
 
         if len(state_groups_ids) == 0:
-            defer.returnValue(_StateCacheEntry(
-                state={},
-                state_group=None,
-            ))
+            defer.returnValue(_StateCacheEntry(state={}, state_group=None))
         elif len(state_groups_ids) == 1:
             name, state_list = list(state_groups_ids.items()).pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
 
-            defer.returnValue(_StateCacheEntry(
-                state=state_list,
-                state_group=name,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-            ))
+            defer.returnValue(
+                _StateCacheEntry(
+                    state=state_list,
+                    state_group=name,
+                    prev_group=prev_group,
+                    delta_ids=delta_ids,
+                )
+            )
 
         room_version = yield self.store.get_room_version(room_id)
 
         result = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, room_version, state_groups_ids, None,
+            room_id,
+            room_version,
+            state_groups_ids,
+            None,
             state_res_store=StateResolutionStore(self.store),
         )
         defer.returnValue(result)
@@ -403,27 +399,21 @@ class StateHandler(object):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
+        state_set_ids = [
+            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
+        ]
+
+        state_map = {ev.event_id: ev for st in state_sets for ev in st}
 
         with Measure(self.clock, "state._resolve_events"):
             new_state = yield resolve_events_with_store(
-                room_version, state_set_ids,
+                room_version,
+                state_set_ids,
                 event_map=state_map,
                 state_res_store=StateResolutionStore(self.store),
             )
 
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in iteritems(new_state)
-        }
+        new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
 
         defer.returnValue(new_state)
 
@@ -434,6 +424,7 @@ class StateResolutionHandler(object):
     Note that the storage layer depends on this handler, so all functions must
     be storage-independent.
     """
+
     def __init__(self, hs):
         self.clock = hs.get_clock()
 
@@ -453,7 +444,7 @@ class StateResolutionHandler(object):
     @defer.inlineCallbacks
     @log_function
     def resolve_state_groups(
-        self, room_id, room_version, state_groups_ids, event_map, state_res_store,
+        self, room_id, room_version, state_groups_ids, event_map, state_res_store
     ):
         """Resolves conflicts between a set of state groups
 
@@ -480,10 +471,7 @@ class StateResolutionHandler(object):
         Returns:
             Deferred[_StateCacheEntry]: resolved state
         """
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
+        logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
 
         group_names = frozenset(state_groups_ids.keys())
 
@@ -540,10 +528,7 @@ class StateResolutionHandler(object):
             defer.returnValue(cache)
 
 
-def _make_state_cache_entry(
-    new_state,
-    state_groups_ids,
-):
+def _make_state_cache_entry(new_state, state_groups_ids):
     """Given a resolved state, and a set of input state groups, pick one to base
     a new state group on (if any), and return an appropriately-constructed
     _StateCacheEntry.
@@ -573,10 +558,7 @@ def _make_state_cache_entry(
         old_state_event_ids = set(itervalues(state))
         if new_state_event_ids == old_state_event_ids:
             # got an exact match.
-            return _StateCacheEntry(
-                state=new_state,
-                state_group=sg,
-            )
+            return _StateCacheEntry(state=new_state, state_group=sg)
 
     # TODO: We want to create a state group for this set of events, to
     # increase cache hits, but we need to make sure that it doesn't
@@ -587,20 +569,13 @@ def _make_state_cache_entry(
     delta_ids = None
 
     for old_group, old_state in iteritems(state_groups_ids):
-        n_delta_ids = {
-            k: v
-            for k, v in iteritems(new_state)
-            if old_state.get(k) != v
-        }
+        n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
             delta_ids = n_delta_ids
 
     return _StateCacheEntry(
-        state=new_state,
-        state_group=None,
-        prev_group=prev_group,
-        delta_ids=delta_ids,
+        state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
     )
 
 
@@ -629,11 +604,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     v = KNOWN_ROOM_VERSIONS[room_version]
     if v.state_res == StateResolutionVersions.V1:
         return v1.resolve_events_with_store(
-            state_sets, event_map, state_res_store.get_events,
+            state_sets, event_map, state_res_store.get_events
         )
     else:
         return v2.resolve_events_with_store(
-            room_version, state_sets, event_map, state_res_store,
+            room_version, state_sets, event_map, state_res_store
         )
 
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 29b4e86cfd..88acd4817e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     if len(state_sets) == 1:
         defer.returnValue(state_sets[0])
 
-    unconflicted_state, conflicted_state = _seperate(
-        state_sets,
-    )
+    unconflicted_state, conflicted_state = _seperate(state_sets)
 
     needed_events = set(
-        event_id
-        for event_ids in itervalues(conflicted_state)
-        for event_id in event_ids
+        event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
     )
     needed_event_count = len(needed_events)
     if event_map is not None:
         needed_events -= set(iterkeys(event_map))
 
     logger.info(
-        "Asking for %d/%d conflicted events",
-        len(needed_events),
-        needed_event_count,
+        "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
     )
 
     # dict[str, FrozenEvent]: a map from state event id to event. Only includes
@@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
         new_needed_events -= set(iterkeys(event_map))
 
     logger.info(
-        "Asking for %d/%d auth events",
-        len(new_needed_events),
-        new_needed_event_count,
+        "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
     )
 
     state_map_new = yield state_map_factory(new_needed_events)
     state_map.update(state_map_new)
 
-    defer.returnValue(_resolve_with_state(
-        unconflicted_state, conflicted_state, auth_events, state_map
-    ))
+    defer.returnValue(
+        _resolve_with_state(
+            unconflicted_state, conflicted_state, auth_events, state_map
+        )
+    )
 
 
 def _seperate(state_sets):
@@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
     return auth_events
 
 
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
-                        state_map):
+def _resolve_with_state(
+    unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+):
     conflicted_state = {}
     for key, event_ids in iteritems(conflicted_state_ids):
         events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
@@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event
     }
 
     try:
-        resolved_state = _resolve_state_events(
-            conflicted_state, auth_events
-        )
+        resolved_state = _resolve_state_events(conflicted_state, auth_events)
     except Exception:
         logger.exception("Failed to resolve state")
         raise
@@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events):
     if POWER_KEY in conflicted_state:
         events = conflicted_state[POWER_KEY]
         logger.debug("Resolving conflicted power levels %r", events)
-        resolved_state[POWER_KEY] = _resolve_auth_events(
-            events, auth_events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
 
     auth_events.update(resolved_state)
 
     for key, events in iteritems(conflicted_state):
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
-            resolved_state[key] = _resolve_auth_events(
-                events,
-                auth_events
-            )
+            resolved_state[key] = _resolve_auth_events(events, auth_events)
 
     auth_events.update(resolved_state)
 
     for key, events in iteritems(conflicted_state):
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
-            resolved_state[key] = _resolve_auth_events(
-                events,
-                auth_events
-            )
+            resolved_state[key] = _resolve_auth_events(events, auth_events)
 
     auth_events.update(resolved_state)
 
     for key, events in iteritems(conflicted_state):
         if key not in resolved_state:
             logger.debug("Resolving conflicted state %r:%r", key, events)
-            resolved_state[key] = _resolve_normal_events(
-                events, auth_events
-            )
+            resolved_state[key] = _resolve_normal_events(events, auth_events)
 
     return resolved_state
 
@@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events):
     reverse = [i for i in reversed(_ordered_events(events))]
 
     auth_keys = set(
-        key
-        for event in events
-        for key in event_auth.auth_types_for_event(event)
+        key for event in events for key in event_auth.auth_types_for_event(event)
     )
 
     new_auth_events = {}
@@ -313,6 +295,6 @@ def _ordered_events(events):
     def key_func(e):
         # we have to use utf-8 rather than ascii here because it turns out we allow
         # people to send us events with non-ascii event IDs :/
-        return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest()
+        return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
 
     return sorted(events, key=key_func)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 650995c92c..db969e8997 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = yield _get_auth_chain_difference(
-        state_sets, event_map, state_res_store,
-    )
+    auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
 
-    full_conflicted_set = set(itertools.chain(
-        itertools.chain.from_iterable(itervalues(conflicted_state)),
-        auth_diff,
-    ))
+    full_conflicted_set = set(
+        itertools.chain(
+            itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+        )
+    )
 
-    events = yield state_res_store.get_events([
-        eid for eid in full_conflicted_set
-        if eid not in event_map
-    ], allow_rejected=True)
+    events = yield state_res_store.get_events(
+        [eid for eid in full_conflicted_set if eid not in event_map],
+        allow_rejected=True,
+    )
     event_map.update(events)
 
     full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
@@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     # Get and sort all the power events (kicks/bans/etc)
     power_events = (
-        eid for eid in full_conflicted_set
-        if _is_power_event(event_map[eid])
+        eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
     )
 
     sorted_power_events = yield _reverse_topological_power_sort(
-        power_events,
-        event_map,
-        state_res_store,
-        full_conflicted_set,
+        power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
 
     # Now sequentially auth each one
     resolved_state = yield _iterative_auth_checks(
-        room_version, sorted_power_events, unconflicted_state, event_map,
+        room_version,
+        sorted_power_events,
+        unconflicted_state,
+        event_map,
         state_res_store,
     )
 
@@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     # events using the mainline of the resolved power level.
 
     leftover_events = [
-        ev_id
-        for ev_id in full_conflicted_set
-        if ev_id not in sorted_power_events
+        ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
     ]
 
     logger.debug("sorting %d remaining events", len(leftover_events))
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
     leftover_events = yield _mainline_sort(
-        leftover_events, pl, event_map, state_res_store,
+        leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
 
     resolved_state = yield _iterative_auth_checks(
-        room_version, leftover_events, resolved_state, event_map,
-        state_res_store,
+        room_version, leftover_events, resolved_state, event_map, state_res_store
     )
 
     logger.debug("resolved")
@@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
         auth_ids = set(
             eid
             for key, eid in iteritems(state_set)
-            if (key[0] in (
-                EventTypes.Member,
-                EventTypes.ThirdPartyInvite,
-            ) or key in (
-                (EventTypes.PowerLevels, ''),
-                (EventTypes.Create, ''),
-                (EventTypes.JoinRules, ''),
-            )) and eid not in common
+            if (
+                key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
+                or key
+                in (
+                    (EventTypes.PowerLevels, ""),
+                    (EventTypes.Create, ""),
+                    (EventTypes.JoinRules, ""),
+                )
+            )
+            and eid not in common
         )
 
         auth_chain = yield state_res_store.get_auth_chain(auth_ids)
@@ -274,15 +271,16 @@ def _is_power_event(event):
         return True
 
     if event.type == EventTypes.Member:
-        if event.membership in ('leave', 'ban'):
+        if event.membership in ("leave", "ban"):
             return event.sender != event.state_key
 
     return False
 
 
 @defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
-                                       state_res_store, auth_diff):
+def _add_event_and_auth_chain_to_graph(
+    graph, event_id, event_map, state_res_store, auth_diff
+):
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
 
@@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
     graph = {}
     for event_id in event_ids:
         yield _add_event_and_auth_chain_to_graph(
-            graph, event_id, event_map, state_res_store, auth_diff,
+            graph, event_id, event_map, state_res_store, auth_diff
         )
 
     event_to_pl = {}
@@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
         return -pl, ev.origin_server_ts, event_id
 
     # Note: graph is modified during the sort
-    it = lexicographical_topological_sort(
-        graph,
-        key=_get_power_order,
-    )
+    it = lexicographical_topological_sort(graph, key=_get_power_order)
     sorted_events = list(it)
 
     defer.returnValue(sorted_events)
 
 
 @defer.inlineCallbacks
-def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
-                           state_res_store):
+def _iterative_auth_checks(
+    room_version, event_ids, base_state, event_map, state_res_store
+):
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
@@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
 
         try:
             event_auth.check(
-                room_version, event, auth_events,
+                room_version,
+                event,
+                auth_events,
                 do_sig_check=False,
-                do_size_check=False
+                do_size_check=False,
             )
 
             resolved_state[(event.type, event.state_key)] = event_id
@@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
 
 
 @defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map,
-                   state_res_store):
+def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
@@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
     order_map = {}
     for ev_id in event_ids:
         depth = yield _get_mainline_depth_for_event(
-            event_map[ev_id], mainline_map,
-            event_map, state_res_store,
+            event_map[ev_id], mainline_map, event_map, state_res_store
         )
         order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)