diff options
author | Erik Johnston <erik@matrix.org> | 2019-06-20 11:59:14 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-06-20 11:59:14 +0100 |
commit | 45f28a9d2fc0466dcf2a05b0063b7caa3b7e12c3 (patch) | |
tree | 07bb21377c6611db89f64f948a2e27645662ff0e /synapse/state | |
parent | Add descriptions and remove redundant set(..) (diff) | |
parent | Run Black. (#5482) (diff) | |
download | synapse-45f28a9d2fc0466dcf2a05b0063b7caa3b7e12c3.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/histogram_extremities
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 107 | ||||
-rw-r--r-- | synapse/state/v1.py | 56 | ||||
-rw-r--r-- | synapse/state/v2.py | 92 |
3 files changed, 104 insertions, 151 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b136b3054..1b454a56a1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -107,8 +107,9 @@ class StateHandler(object): self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state( + self, room_id, event_type=None, state_key="", latest_event_ids=None + ): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -137,8 +138,9 @@ class StateHandler(object): defer.returnValue(event) return - state_map = yield self.store.get_events(list(state.values()), - get_prev_content=False) + state_map = yield self.store.get_events( + list(state.values()), get_prev_content=False + ) state = { key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } @@ -220,9 +222,7 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) @@ -248,9 +248,7 @@ class StateHandler(object): # Let's just correctly fill out the context and create a # new state group for it. - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): key = (event.type, event.state_key) @@ -282,7 +280,7 @@ class StateHandler(object): logger.debug("calling resolve_state_groups from compute_event_context") entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids(), + event.room_id, event.prev_event_ids() ) prev_state_ids = entry.state @@ -305,9 +303,7 @@ class StateHandler(object): # If the state at the event has a state group assigned then # we can use that as the prev group prev_group = entry.state_group - delta_ids = { - key: event.event_id - } + delta_ids = {key: event.event_id} elif entry.prev_group: # If the state at the event only has a prev group, then we can # use that as a prev group too. @@ -369,31 +365,31 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids( - room_id, event_ids - ) + state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) if len(state_groups_ids) == 0: - defer.returnValue(_StateCacheEntry( - state={}, - state_group=None, - )) + defer.returnValue(_StateCacheEntry(state={}, state_group=None)) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) - defer.returnValue(_StateCacheEntry( - state=state_list, - state_group=name, - prev_group=prev_group, - delta_ids=delta_ids, - )) + defer.returnValue( + _StateCacheEntry( + state=state_list, + state_group=name, + prev_group=prev_group, + delta_ids=delta_ids, + ) + ) room_version = yield self.store.get_room_version(room_id) result = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups_ids, None, + room_id, + room_version, + state_groups_ids, + None, state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) @@ -403,27 +399,21 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } + state_set_ids = [ + {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets + ] + + state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( - room_version, state_set_ids, + room_version, + state_set_ids, event_map=state_map, state_res_store=StateResolutionStore(self.store), ) - new_state = { - key: state_map[ev_id] for key, ev_id in iteritems(new_state) - } + new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} defer.returnValue(new_state) @@ -434,6 +424,7 @@ class StateResolutionHandler(object): Note that the storage layer depends on this handler, so all functions must be storage-independent. """ + def __init__(self, hs): self.clock = hs.get_clock() @@ -453,7 +444,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store, + self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -480,10 +471,7 @@ class StateResolutionHandler(object): Returns: Deferred[_StateCacheEntry]: resolved state """ - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) + logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) @@ -540,10 +528,7 @@ class StateResolutionHandler(object): defer.returnValue(cache) -def _make_state_cache_entry( - new_state, - state_groups_ids, -): +def _make_state_cache_entry(new_state, state_groups_ids): """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. @@ -573,10 +558,7 @@ def _make_state_cache_entry( old_state_event_ids = set(itervalues(state)) if new_state_event_ids == old_state_event_ids: # got an exact match. - return _StateCacheEntry( - state=new_state, - state_group=sg, - ) + return _StateCacheEntry(state=new_state, state_group=sg) # TODO: We want to create a state group for this set of events, to # increase cache hits, but we need to make sure that it doesn't @@ -587,20 +569,13 @@ def _make_state_cache_entry( delta_ids = None for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = { - k: v - for k, v in iteritems(new_state) - if old_state.get(k) != v - } + n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids return _StateCacheEntry( - state=new_state, - state_group=None, - prev_group=prev_group, - delta_ids=delta_ids, + state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids ) @@ -629,11 +604,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events, + state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store, + room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 29b4e86cfd..88acd4817e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if len(state_sets) == 1: defer.returnValue(state_sets[0]) - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) + unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids + event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids ) needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d conflicted events", - len(needed_events), - needed_event_count, + "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes @@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): new_needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d auth events", - len(new_needed_events), - new_needed_event_count, + "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) + defer.returnValue( + _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + ) def _seperate(state_sets): @@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma return auth_events -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): +def _resolve_with_state( + unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map +): conflicted_state = {} for key, event_ids in iteritems(conflicted_state_ids): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] @@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event } try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events - ) + resolved_state = _resolve_state_events(conflicted_state, auth_events) except Exception: logger.exception("Failed to resolve state") raise @@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events): if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) + resolved_state[key] = _resolve_normal_events(events, auth_events) return resolved_state @@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) + key for event in events for key in event_auth.auth_types_for_event(event) ) new_auth_events = {} @@ -313,6 +295,6 @@ def _ordered_events(events): def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ - return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 650995c92c..db969e8997 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference( - state_sets, event_map, state_res_store, - ) + auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) - full_conflicted_set = set(itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), - auth_diff, - )) + full_conflicted_set = set( + itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + ) + ) - events = yield state_res_store.get_events([ - eid for eid in full_conflicted_set - if eid not in event_map - ], allow_rejected=True) + events = yield state_res_store.get_events( + [eid for eid in full_conflicted_set if eid not in event_map], + allow_rejected=True, + ) event_map.update(events) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) @@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Get and sort all the power events (kicks/bans/etc) power_events = ( - eid for eid in full_conflicted_set - if _is_power_event(event_map[eid]) + eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, - event_map, - state_res_store, - full_conflicted_set, + power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( - room_version, sorted_power_events, unconflicted_state, event_map, + room_version, + sorted_power_events, + unconflicted_state, + event_map, state_res_store, ) @@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # events using the mainline of the resolved power level. leftover_events = [ - ev_id - for ev_id in full_conflicted_set - if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store, + leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, - state_res_store, + room_version, leftover_events, resolved_state, event_map, state_res_store ) logger.debug("resolved") @@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_ids = set( eid for key, eid in iteritems(state_set) - if (key[0] in ( - EventTypes.Member, - EventTypes.ThirdPartyInvite, - ) or key in ( - (EventTypes.PowerLevels, ''), - (EventTypes.Create, ''), - (EventTypes.JoinRules, ''), - )) and eid not in common + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) + ) + and eid not in common ) auth_chain = yield state_res_store.get_auth_chain(auth_ids) @@ -274,15 +271,16 @@ def _is_power_event(event): return True if event.type == EventTypes.Member: - if event.membership in ('leave', 'ban'): + if event.membership in ("leave", "ban"): return event.sender != event.state_key return False @defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, - state_res_store, auth_diff): +def _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff +): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff, + graph, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} @@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ return -pl, ev.origin_server_ts, event_id # Note: graph is modified during the sort - it = lexicographical_topological_sort( - graph, - key=_get_power_order, - ) + it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) defer.returnValue(sorted_events) @defer.inlineCallbacks -def _iterative_auth_checks(room_version, event_ids, base_state, event_map, - state_res_store): +def _iterative_auth_checks( + room_version, event_ids, base_state, event_map, state_res_store +): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, try: event_auth.check( - room_version, event, auth_events, + room_version, + event, + auth_events, do_sig_check=False, - do_size_check=False + do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id @@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, - state_res_store): +def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, order_map = {} for ev_id in event_ids: depth = yield _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, - event_map, state_res_store, + event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) |