summary refs log tree commit diff
path: root/synapse/state/v1.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v1.py')
-rw-r--r--synapse/state/v1.py49
1 files changed, 37 insertions, 12 deletions
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index a2f92d9ff9..9bf98d06f2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,6 +15,7 @@
 
 import hashlib
 import logging
+from typing import Callable, Dict, List, Optional
 
 from six import iteritems, iterkeys, itervalues
 
@@ -24,6 +25,8 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -32,13 +35,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 
 @defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(
+    room_id: str,
+    state_sets: List[StateMap[str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_map_factory: Callable,
+):
     """
     Args:
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        room_id: the room we are working in
+
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
@@ -46,11 +56,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
 
             If None, all events will be fetched via state_map_factory.
 
-        state_map_factory(func): will be called
+        state_map_factory: will be called
             with a list of event_ids that are needed, and should return with
             a Deferred of dict of event_id to event.
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
@@ -59,9 +69,9 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
 
     unconflicted_state, conflicted_state = _seperate(state_sets)
 
-    needed_events = set(
+    needed_events = {
         event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
-    )
+    }
     needed_event_count = len(needed_events)
     if event_map is not None:
         needed_events -= set(iterkeys(event_map))
@@ -76,6 +86,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     if event_map is not None:
         state_map.update(event_map)
 
+    # everything in the state map should be in the right room
+    for event in state_map.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     #
@@ -95,6 +113,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     )
 
     state_map_new = yield state_map_factory(new_needed_events)
+    for event in state_map_new.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     state_map.update(state_map_new)
 
     return _resolve_with_state(
@@ -236,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events):
 
 
 def _resolve_auth_events(events, auth_events):
-    reverse = [i for i in reversed(_ordered_events(events))]
+    reverse = list(reversed(_ordered_events(events)))
 
-    auth_keys = set(
+    auth_keys = {
         key for event in events for key in event_auth.auth_types_for_event(event)
-    )
+    }
 
     new_auth_events = {}
     for key in auth_keys:
@@ -256,7 +281,7 @@ def _resolve_auth_events(events, auth_events):
         try:
             # The signatures have already been checked at this point
             event_auth.check(
-                RoomVersions.V1.identifier,
+                RoomVersions.V1,
                 event,
                 auth_events,
                 do_sig_check=False,
@@ -274,7 +299,7 @@ def _resolve_normal_events(events, auth_events):
         try:
             # The signatures have already been checked at this point
             event_auth.check(
-                RoomVersions.V1.identifier,
+                RoomVersions.V1,
                 event,
                 auth_events,
                 do_sig_check=False,