diff options
Diffstat (limited to 'synapse/state/v1.py')
-rw-r--r-- | synapse/state/v1.py | 49 |
1 files changed, 37 insertions, 12 deletions
diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -24,6 +25,8 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -32,13 +35,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +56,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -59,9 +69,9 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -76,6 +86,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +113,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( @@ -236,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: @@ -256,7 +281,7 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, @@ -274,7 +299,7 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, |