summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py12
-rw-r--r--synapse/state/v1.py40
-rw-r--r--synapse/state/v2.py11
3 files changed, 42 insertions, 21 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..2e15471435 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -636,16 +636,20 @@ class StateResolutionHandler:
         """
         try:
             with Measure(self.clock, "state._resolve_events") as m:
-                v = KNOWN_ROOM_VERSIONS[room_version]
-                if v.state_res == StateResolutionVersions.V1:
+                room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+                if room_version_obj.state_res == StateResolutionVersions.V1:
                     return await v1.resolve_events_with_store(
-                        room_id, state_sets, event_map, state_res_store.get_events
+                        room_id,
+                        room_version_obj,
+                        state_sets,
+                        event_map,
+                        state_res_store.get_events,
                     )
                 else:
                     return await v2.resolve_events_with_store(
                         self.clock,
                         room_id,
-                        room_version,
+                        room_version_obj,
                         state_sets,
                         event_map,
                         state_res_store,
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     auth_events = _create_auth_events_from_maps(
-        unconflicted_state, conflicted_state, state_map
+        room_version, unconflicted_state, conflicted_state, state_map
     )
 
     new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
     state_map.update(state_map_new)
 
     return _resolve_with_state(
-        unconflicted_state, conflicted_state, auth_events, state_map
+        room_version, unconflicted_state, conflicted_state, auth_events, state_map
     )
 
 
@@ -187,6 +188,7 @@ def _seperate(
 
 
 def _create_auth_events_from_maps(
+    room_version: RoomVersion,
     unconflicted_state: StateMap[str],
     conflicted_state: StateMap[Set[str]],
     state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
     """
 
     Args:
+        room_version: The room version.
         unconflicted_state: The unconflicted state map.
         conflicted_state: The conflicted state map.
         state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
             if event_id in state_map:
-                keys = event_auth.auth_types_for_event(state_map[event_id])
+                keys = event_auth.auth_types_for_event(
+                    room_version, state_map[event_id]
+                )
                 for key in keys:
                     if key not in auth_events:
                         auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
 
 
 def _resolve_with_state(
+    room_version: RoomVersion,
     unconflicted_state_ids: MutableStateMap[str],
     conflicted_state_ids: StateMap[Set[str]],
     auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
     }
 
     try:
-        resolved_state = _resolve_state_events(conflicted_state, auth_events)
+        resolved_state = _resolve_state_events(
+            room_version, conflicted_state, auth_events
+        )
     except Exception:
         logger.exception("Failed to resolve state")
         raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
 
 
 def _resolve_state_events(
-    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+    room_version: RoomVersion,
+    conflicted_state: StateMap[List[EventBase]],
+    auth_events: MutableStateMap[EventBase],
 ) -> StateMap[EventBase]:
     """This is where we actually decide which of the conflicted state to
     use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
     if POWER_KEY in conflicted_state:
         events = conflicted_state[POWER_KEY]
         logger.debug("Resolving conflicted power levels %r", events)
-        resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(
+            room_version, events, auth_events
+        )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
@@ -290,12 +306,14 @@ def _resolve_state_events(
 
 
 def _resolve_auth_events(
-    events: List[EventBase], auth_events: StateMap[EventBase]
+    room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
 ) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
-        key for event in events for key in event_auth.auth_types_for_event(event)
+        key
+        for event in events
+        for key in event_auth.auth_types_for_event(room_version, event)
     }
 
     new_auth_events = {}
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
 async def _iterative_auth_checks(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
         Returns the final updated state
     """
     resolved_state = dict(base_state)
-    room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
     for idx, event_id in enumerate(event_ids, start=1):
         event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
                 if ev.rejected_reason is None:
                     auth_events[(ev.type, ev.state_key)] = ev
 
-        for key in event_auth.auth_types_for_event(event):
+        for key in event_auth.auth_types_for_event(room_version, event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
                 ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
 
         try:
             event_auth.check(
-                room_version_obj,
+                room_version,
                 event,
                 auth_events,
                 do_sig_check=False,