summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state/v2.py92
1 files changed, 44 insertions, 48 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 650995c92c..db969e8997 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = yield _get_auth_chain_difference(
-        state_sets, event_map, state_res_store,
-    )
+    auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
 
-    full_conflicted_set = set(itertools.chain(
-        itertools.chain.from_iterable(itervalues(conflicted_state)),
-        auth_diff,
-    ))
+    full_conflicted_set = set(
+        itertools.chain(
+            itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+        )
+    )
 
-    events = yield state_res_store.get_events([
-        eid for eid in full_conflicted_set
-        if eid not in event_map
-    ], allow_rejected=True)
+    events = yield state_res_store.get_events(
+        [eid for eid in full_conflicted_set if eid not in event_map],
+        allow_rejected=True,
+    )
     event_map.update(events)
 
     full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
@@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     # Get and sort all the power events (kicks/bans/etc)
     power_events = (
-        eid for eid in full_conflicted_set
-        if _is_power_event(event_map[eid])
+        eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
     )
 
     sorted_power_events = yield _reverse_topological_power_sort(
-        power_events,
-        event_map,
-        state_res_store,
-        full_conflicted_set,
+        power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
 
     # Now sequentially auth each one
     resolved_state = yield _iterative_auth_checks(
-        room_version, sorted_power_events, unconflicted_state, event_map,
+        room_version,
+        sorted_power_events,
+        unconflicted_state,
+        event_map,
         state_res_store,
     )
 
@@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     # events using the mainline of the resolved power level.
 
     leftover_events = [
-        ev_id
-        for ev_id in full_conflicted_set
-        if ev_id not in sorted_power_events
+        ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
     ]
 
     logger.debug("sorting %d remaining events", len(leftover_events))
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
     leftover_events = yield _mainline_sort(
-        leftover_events, pl, event_map, state_res_store,
+        leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
 
     resolved_state = yield _iterative_auth_checks(
-        room_version, leftover_events, resolved_state, event_map,
-        state_res_store,
+        room_version, leftover_events, resolved_state, event_map, state_res_store
     )
 
     logger.debug("resolved")
@@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
         auth_ids = set(
             eid
             for key, eid in iteritems(state_set)
-            if (key[0] in (
-                EventTypes.Member,
-                EventTypes.ThirdPartyInvite,
-            ) or key in (
-                (EventTypes.PowerLevels, ''),
-                (EventTypes.Create, ''),
-                (EventTypes.JoinRules, ''),
-            )) and eid not in common
+            if (
+                key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
+                or key
+                in (
+                    (EventTypes.PowerLevels, ""),
+                    (EventTypes.Create, ""),
+                    (EventTypes.JoinRules, ""),
+                )
+            )
+            and eid not in common
         )
 
         auth_chain = yield state_res_store.get_auth_chain(auth_ids)
@@ -274,15 +271,16 @@ def _is_power_event(event):
         return True
 
     if event.type == EventTypes.Member:
-        if event.membership in ('leave', 'ban'):
+        if event.membership in ("leave", "ban"):
             return event.sender != event.state_key
 
     return False
 
 
 @defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
-                                       state_res_store, auth_diff):
+def _add_event_and_auth_chain_to_graph(
+    graph, event_id, event_map, state_res_store, auth_diff
+):
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
 
@@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
     graph = {}
     for event_id in event_ids:
         yield _add_event_and_auth_chain_to_graph(
-            graph, event_id, event_map, state_res_store, auth_diff,
+            graph, event_id, event_map, state_res_store, auth_diff
         )
 
     event_to_pl = {}
@@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
         return -pl, ev.origin_server_ts, event_id
 
     # Note: graph is modified during the sort
-    it = lexicographical_topological_sort(
-        graph,
-        key=_get_power_order,
-    )
+    it = lexicographical_topological_sort(graph, key=_get_power_order)
     sorted_events = list(it)
 
     defer.returnValue(sorted_events)
 
 
 @defer.inlineCallbacks
-def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
-                           state_res_store):
+def _iterative_auth_checks(
+    room_version, event_ids, base_state, event_map, state_res_store
+):
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
@@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
 
         try:
             event_auth.check(
-                room_version, event, auth_events,
+                room_version,
+                event,
+                auth_events,
                 do_sig_check=False,
-                do_size_check=False
+                do_size_check=False,
             )
 
             resolved_state[(event.type, event.state_key)] = event_id
@@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
 
 
 @defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map,
-                   state_res_store):
+def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
@@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
     order_map = {}
     for ev_id in event_ids:
         depth = yield _get_mainline_depth_for_event(
-            event_map[ev_id], mainline_map,
-            event_map, state_res_store,
+            event_map[ev_id], mainline_map, event_map, state_res_store
         )
         order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)