diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 650995c92c..db969e8997 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(
- state_sets, event_map, state_res_store,
- )
+ auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
- full_conflicted_set = set(itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)),
- auth_diff,
- ))
+ full_conflicted_set = set(
+ itertools.chain(
+ itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ )
+ )
- events = yield state_res_store.get_events([
- eid for eid in full_conflicted_set
- if eid not in event_map
- ], allow_rejected=True)
+ events = yield state_res_store.get_events(
+ [eid for eid in full_conflicted_set if eid not in event_map],
+ allow_rejected=True,
+ )
event_map.update(events)
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
@@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# Get and sort all the power events (kicks/bans/etc)
power_events = (
- eid for eid in full_conflicted_set
- if _is_power_event(event_map[eid])
+ eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
sorted_power_events = yield _reverse_topological_power_sort(
- power_events,
- event_map,
- state_res_store,
- full_conflicted_set,
+ power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
- room_version, sorted_power_events, unconflicted_state, event_map,
+ room_version,
+ sorted_power_events,
+ unconflicted_state,
+ event_map,
state_res_store,
)
@@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# events using the mainline of the resolved power level.
leftover_events = [
- ev_id
- for ev_id in full_conflicted_set
- if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- leftover_events, pl, event_map, state_res_store,
+ leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
- room_version, leftover_events, resolved_state, event_map,
- state_res_store,
+ room_version, leftover_events, resolved_state, event_map, state_res_store
)
logger.debug("resolved")
@@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
auth_ids = set(
eid
for key, eid in iteritems(state_set)
- if (key[0] in (
- EventTypes.Member,
- EventTypes.ThirdPartyInvite,
- ) or key in (
- (EventTypes.PowerLevels, ''),
- (EventTypes.Create, ''),
- (EventTypes.JoinRules, ''),
- )) and eid not in common
+ if (
+ key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
+ or key
+ in (
+ (EventTypes.PowerLevels, ""),
+ (EventTypes.Create, ""),
+ (EventTypes.JoinRules, ""),
+ )
+ )
+ and eid not in common
)
auth_chain = yield state_res_store.get_auth_chain(auth_ids)
@@ -274,15 +271,16 @@ def _is_power_event(event):
return True
if event.type == EventTypes.Member:
- if event.membership in ('leave', 'ban'):
+ if event.membership in ("leave", "ban"):
return event.sender != event.state_key
return False
@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
- state_res_store, auth_diff):
+def _add_event_and_auth_chain_to_graph(
+ graph, event_id, event_map, state_res_store, auth_diff
+):
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
graph = {}
for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph(
- graph, event_id, event_map, state_res_store, auth_diff,
+ graph, event_id, event_map, state_res_store, auth_diff
)
event_to_pl = {}
@@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
return -pl, ev.origin_server_ts, event_id
# Note: graph is modified during the sort
- it = lexicographical_topological_sort(
- graph,
- key=_get_power_order,
- )
+ it = lexicographical_topological_sort(graph, key=_get_power_order)
sorted_events = list(it)
defer.returnValue(sorted_events)
@defer.inlineCallbacks
-def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
- state_res_store):
+def _iterative_auth_checks(
+ room_version, event_ids, base_state, event_map, state_res_store
+):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
@@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
try:
event_auth.check(
- room_version, event, auth_events,
+ room_version,
+ event,
+ auth_events,
do_sig_check=False,
- do_size_check=False
+ do_size_check=False,
)
resolved_state[(event.type, event.state_key)] = event_id
@@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
@defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map,
- state_res_store):
+def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
@@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
order_map = {}
for ev_id in event_ids:
depth = yield _get_mainline_depth_for_event(
- event_map[ev_id], mainline_map,
- event_map, state_res_store,
+ event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
|