diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 269 | ||||
-rw-r--r-- | synapse/state/v1.py | 49 | ||||
-rw-r--r-- | synapse/state/v2.py | 181 |
3 files changed, 293 insertions, 206 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 2b0f4c79ee..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,6 +16,7 @@ import logging from collections import namedtuple +from typing import Dict, Iterable, List, Optional, Set from six import iteritems, itervalues @@ -27,13 +28,15 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.metrics import Measure +from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -49,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -103,6 +105,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -191,24 +194,37 @@ class StateHandler(object): return joined_users @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id, latest_event_ids=None): - if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.debug("calling resolve_state_groups from get_current_hosts_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + def get_current_hosts_in_room(self, room_id): + event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + + @defer.inlineCallbacks + def get_hosts_in_room_at_events(self, room_id, event_ids): + """Get the hosts that were in a room at the given event ids + + Args: + room_id (str): + event_ids (list[str]): + + Returns: + Deferred[list[str]]: the hosts in the room at the given events + """ + entry = yield self.resolve_state_groups_for_events(room_id, event_ids) joined_hosts = yield self.store.get_joined_hosts(room_id, entry) return joined_hosts @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None): + def compute_event_context( + self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + ): """Build an EventContext structure for the event. This works out what the current state should be for the event, and generates a new state group if necessary. Args: - event (synapse.events.EventBase): - old_state (dict|None): The state at the event if it can't be + event: + old_state: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. @@ -220,6 +236,9 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + + # FIXME: why do we populate current_state_ids? I thought the point was + # that we weren't supposed to have any state for outliers? if old_state: prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): @@ -236,114 +255,105 @@ class StateHandler(object): # group for it. context = EventContext.with_state( state_group=None, + state_group_before_event=None, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, ) return context + # + # first of all, figure out the state before the event + # + if old_state: - # We already have the state, so we don't need to calculate it. - # Let's just correctly fill out the context and create a - # new state group for it. - - prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} - - if event.is_state(): - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] - if replaces != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - else: - current_state_ids = prev_state_ids + # if we're given the state before the event, then we use that + state_ids_before_event = { + (s.type, s.state_key): s.event_id for s in old_state + } + state_group_before_event = None + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=current_state_ids, - ) + else: + # otherwise, we'll need to resolve the state across the prev_events. + logger.debug("calling resolve_state_groups from compute_event_context") - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + entry = yield self.resolve_state_groups_for_events( + event.room_id, event.prev_event_ids() ) - return context + state_ids_before_event = entry.state + state_group_before_event = entry.state_group + state_group_before_event_prev_group = entry.prev_group + deltas_to_state_group_before_event = entry.delta_ids - logger.debug("calling resolve_state_groups from compute_event_context") + # + # make sure that we have a state group at that point. If it's not a state event, + # that will be the state group for the new event. If it *is* a state event, + # it might get rejected (in which case we'll need to persist it with the + # previous state group) + # - entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() - ) + if not state_group_before_event: + state_group_before_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) - prev_state_ids = entry.state - prev_group = None - delta_ids = None + # XXX: can we update the state cache entry for the new state group? or + # could we set a flag on resolve_state_groups_for_events to tell it to + # always make a state group? + + # + # now if it's not a state event, we're done + # + + if not event.is_state(): + return EventContext.with_state( + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + current_state_ids=state_ids_before_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + ) - if event.is_state(): - # If this is a state event then we need to create a new state - # group for the state after this event. + # + # otherwise, we'll need to create a new state group for after the event + # - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] + key = (event.type, event.state_key) + if key in state_ids_before_event: + replaces = state_ids_before_event[key] + if replaces != event.event_id: event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - - if entry.state_group: - # If the state at the event has a state group assigned then - # we can use that as the prev group - prev_group = entry.state_group - delta_ids = {key: event.event_id} - elif entry.prev_group: - # If the state at the event only has a prev group, then we can - # use that as a prev group too. - prev_group = entry.prev_group - delta_ids = dict(entry.delta_ids) - delta_ids[key] = event.event_id - - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=delta_ids, - current_state_ids=current_state_ids, - ) - else: - current_state_ids = prev_state_ids - prev_group = entry.prev_group - delta_ids = entry.delta_ids - - if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=entry.prev_group, - delta_ids=entry.delta_ids, - current_state_ids=current_state_ids, - ) - entry.state_id = entry.state_group - - state_group = entry.state_group - - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, + state_ids_after_event = dict(state_ids_before_event) + state_ids_after_event[key] = event.event_id + delta_ids = {key: event.event_id} + + state_group_after_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, delta_ids=delta_ids, + current_state_ids=state_ids_after_event, ) - return context + return EventContext.with_state( + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + current_state_ids=state_ids_after_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event, + delta_ids=delta_ids, + ) + @measure_func() @defer.inlineCallbacks def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each @@ -364,14 +374,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -380,7 +392,7 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) result = yield self._state_resolution_handler.resolve_state_groups( room_id, @@ -404,6 +416,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -432,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, @@ -449,7 +462,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -505,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -576,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) @@ -633,28 +655,21 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids): - """Gets the full auth chain for a set of events (including rejected - events). - - Includes the given event IDs in the result. + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). - Note that: - 1. All events must be state events. - 2. For v1 rooms this may not have the full auth chain in the - presence of rejected events - - Args: - event_ids (list): The event IDs of the events to fetch the auth - chain for. Must be state events. + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + Deferred[Set[str]]: Set of event IDs. """ - return self.store.get_auth_chain_ids(event_ids, include_given=True) + return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -24,6 +25,8 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -32,13 +35,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +56,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -59,9 +69,9 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -76,6 +86,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +113,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( @@ -236,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: @@ -256,7 +281,7 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, @@ -274,7 +299,7 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..18484e2fa6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,42 @@ import heapq import itertools import logging +from typing import Dict, List, Optional from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in + + room_version: The room version - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +59,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,7 +97,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) - full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + + full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +115,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +143,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +168,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,20 +181,24 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.Create, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 break @@ -195,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - common = set(itervalues(state_sets[0])).intersection( - *(itervalues(s) for s in state_sets[1:]) - ) - - auth_sets = [] - for state_set in state_sets: - auth_ids = set( - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - ) - ) - and eid not in common - ) - auth_chain = yield state_res_store.get_auth_chain(auth_ids) - auth_ids.update(auth_chain) - - auth_sets.append(auth_ids) - - intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) - union = set().union(*auth_sets) + difference = yield state_res_store.get_auth_chain_difference( + [set(state_set.values()) for state_set in state_sets] + ) - return union - intersection + return difference def _seperate(state_sets): @@ -243,7 +251,7 @@ def _seperate(state_sets): conflicted_state = {} for key in set(itertools.chain.from_iterable(state_sets)): - event_ids = set(state_set.get(key) for state_set in state_sets) + event_ids = {state_set.get(key) for state_set in state_sets} if len(event_ids) == 1: unconflicted_state[key] = event_ids.pop() else: @@ -279,7 +287,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +295,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +307,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +317,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +337,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,44 +362,53 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (dict[tuple[str, str], str]): The set of state to start with + base_state (StateMap[str]): The set of state to start with event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) Returns: - Deferred[dict[tuple[str, str], str]]: Returns the final updated state + Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for event_id in event_ids: event = event_map[event_id] auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) - if ev.rejected_reason is None: - auth_events[(ev.type, ev.state_key)] = ev + if not ev: + logger.warning( + "auth_event id %s for event %s is missing", aid, event_id + ) + else: + if ev.rejected_reason is None: + auth_events[(ev.type, ev.state_key)] = ev for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] try: event_auth.check( - room_version, + room_version_obj, event, auth_events, do_sig_check=False, @@ -400,11 +423,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,12 +443,14 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) - if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): + ev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +485,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,8 +498,10 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,22 +510,37 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) + allow_none (bool): if the event is not found, return None rather than raising + an exception Returns: - Deferred[FrozenEvent] + Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map.get(event_id) + + if event is None: + if allow_none: + return None + raise Exception("Unknown event %s" % (event_id,)) + + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): |