diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 12 | ||||
-rw-r--r-- | synapse/state/v1.py | 40 | ||||
-rw-r--r-- | synapse/state/v2.py | 11 |
3 files changed, 42 insertions, 21 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6223daf522..2e15471435 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -636,16 +636,20 @@ class StateResolutionHandler: """ try: with Measure(self.clock, "state._resolve_events") as m: - v = KNOWN_ROOM_VERSIONS[room_version] - if v.state_res == StateResolutionVersions.V1: + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if room_version_obj.state_res == StateResolutionVersions.V1: return await v1.resolve_events_with_store( - room_id, state_sets, event_map, state_res_store.get_events + room_id, + room_version_obj, + state_sets, + event_map, + state_res_store.get_events, ) else: return await v2.resolve_events_with_store( self.clock, room_id, - room_version, + room_version_obj, state_sets, event_map, state_res_store, diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 267193cedf..92336d7cc8 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -29,7 +29,7 @@ from typing import ( from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap @@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") async def resolve_events_with_store( room_id: str, + room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], @@ -104,7 +105,7 @@ async def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map + room_version, unconflicted_state, conflicted_state, state_map ) new_needed_events = set(auth_events.values()) @@ -132,7 +133,7 @@ async def resolve_events_with_store( state_map.update(state_map_new) return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map + room_version, unconflicted_state, conflicted_state, auth_events, state_map ) @@ -187,6 +188,7 @@ def _seperate( def _create_auth_events_from_maps( + room_version: RoomVersion, unconflicted_state: StateMap[str], conflicted_state: StateMap[Set[str]], state_map: Dict[str, EventBase], @@ -194,6 +196,7 @@ def _create_auth_events_from_maps( """ Args: + room_version: The room version. unconflicted_state: The unconflicted state map. conflicted_state: The conflicted state map. state_map: @@ -205,7 +208,9 @@ def _create_auth_events_from_maps( for event_ids in conflicted_state.values(): for event_id in event_ids: if event_id in state_map: - keys = event_auth.auth_types_for_event(state_map[event_id]) + keys = event_auth.auth_types_for_event( + room_version, state_map[event_id] + ) for key in keys: if key not in auth_events: auth_event_id = unconflicted_state.get(key, None) @@ -215,6 +220,7 @@ def _create_auth_events_from_maps( def _resolve_with_state( + room_version: RoomVersion, unconflicted_state_ids: MutableStateMap[str], conflicted_state_ids: StateMap[Set[str]], auth_event_ids: StateMap[str], @@ -235,7 +241,9 @@ def _resolve_with_state( } try: - resolved_state = _resolve_state_events(conflicted_state, auth_events) + resolved_state = _resolve_state_events( + room_version, conflicted_state, auth_events + ) except Exception: logger.exception("Failed to resolve state") raise @@ -248,7 +256,9 @@ def _resolve_with_state( def _resolve_state_events( - conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] + room_version: RoomVersion, + conflicted_state: StateMap[List[EventBase]], + auth_events: MutableStateMap[EventBase], ) -> StateMap[EventBase]: """This is where we actually decide which of the conflicted state to use. @@ -263,21 +273,27 @@ def _resolve_state_events( if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events(events, auth_events) + resolved_state[key] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events(events, auth_events) + resolved_state[key] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) @@ -290,12 +306,14 @@ def _resolve_state_events( def _resolve_auth_events( - events: List[EventBase], auth_events: StateMap[EventBase] + room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase] ) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { - key for event in events for key in event_auth.auth_types_for_event(event) + key + for event in events + for key in event_auth.auth_types_for_event(room_version, event) } new_auth_events = {} diff --git a/synapse/state/v2.py b/synapse/state/v2.py index e66e6571c8..7b1e8361de 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -36,7 +36,7 @@ import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap from synapse.util import Clock @@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100 async def resolve_events_with_store( clock: Clock, room_id: str, - room_version: str, + room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", @@ -497,7 +497,7 @@ async def _reverse_topological_power_sort( async def _iterative_auth_checks( clock: Clock, room_id: str, - room_version: str, + room_version: RoomVersion, event_ids: List[str], base_state: StateMap[str], event_map: Dict[str, EventBase], @@ -519,7 +519,6 @@ async def _iterative_auth_checks( Returns the final updated state """ resolved_state = dict(base_state) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for idx, event_id in enumerate(event_ids, start=1): event = event_map[event_id] @@ -538,7 +537,7 @@ async def _iterative_auth_checks( if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev - for key in event_auth.auth_types_for_event(event): + for key in event_auth.auth_types_for_event(room_version, event): if key in resolved_state: ev_id = resolved_state[key] ev = await _get_event(room_id, ev_id, event_map, state_res_store) @@ -548,7 +547,7 @@ async def _iterative_auth_checks( try: event_auth.check( - room_version_obj, + room_version, event, auth_events, do_sig_check=False, |