diff options
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r-- | synapse/state/v2.py | 92 |
1 files changed, 44 insertions, 48 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 650995c92c..db969e8997 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference( - state_sets, event_map, state_res_store, - ) + auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) - full_conflicted_set = set(itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), - auth_diff, - )) + full_conflicted_set = set( + itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + ) + ) - events = yield state_res_store.get_events([ - eid for eid in full_conflicted_set - if eid not in event_map - ], allow_rejected=True) + events = yield state_res_store.get_events( + [eid for eid in full_conflicted_set if eid not in event_map], + allow_rejected=True, + ) event_map.update(events) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) @@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Get and sort all the power events (kicks/bans/etc) power_events = ( - eid for eid in full_conflicted_set - if _is_power_event(event_map[eid]) + eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, - event_map, - state_res_store, - full_conflicted_set, + power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( - room_version, sorted_power_events, unconflicted_state, event_map, + room_version, + sorted_power_events, + unconflicted_state, + event_map, state_res_store, ) @@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # events using the mainline of the resolved power level. leftover_events = [ - ev_id - for ev_id in full_conflicted_set - if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store, + leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, - state_res_store, + room_version, leftover_events, resolved_state, event_map, state_res_store ) logger.debug("resolved") @@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_ids = set( eid for key, eid in iteritems(state_set) - if (key[0] in ( - EventTypes.Member, - EventTypes.ThirdPartyInvite, - ) or key in ( - (EventTypes.PowerLevels, ''), - (EventTypes.Create, ''), - (EventTypes.JoinRules, ''), - )) and eid not in common + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) + ) + and eid not in common ) auth_chain = yield state_res_store.get_auth_chain(auth_ids) @@ -274,15 +271,16 @@ def _is_power_event(event): return True if event.type == EventTypes.Member: - if event.membership in ('leave', 'ban'): + if event.membership in ("leave", "ban"): return event.sender != event.state_key return False @defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, - state_res_store, auth_diff): +def _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff +): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff, + graph, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} @@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ return -pl, ev.origin_server_ts, event_id # Note: graph is modified during the sort - it = lexicographical_topological_sort( - graph, - key=_get_power_order, - ) + it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) defer.returnValue(sorted_events) @defer.inlineCallbacks -def _iterative_auth_checks(room_version, event_ids, base_state, event_map, - state_res_store): +def _iterative_auth_checks( + room_version, event_ids, base_state, event_map, state_res_store +): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, try: event_auth.check( - room_version, event, auth_events, + room_version, + event, + auth_events, do_sig_check=False, - do_size_check=False + do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id @@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, - state_res_store): +def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, order_map = {} for ev_id in event_ids: depth = yield _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, - event_map, state_res_store, + event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) |