summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/6531.misc1
-rw-r--r--synapse/handlers/federation.py1
-rw-r--r--synapse/state/__init__.py32
-rw-r--r--synapse/state/v1.py34
-rw-r--r--synapse/state/v2.py100
-rw-r--r--tests/state/test_v2.py3
6 files changed, 128 insertions, 43 deletions
diff --git a/changelog.d/6531.misc b/changelog.d/6531.misc
new file mode 100644
index 0000000000..598efb79fc
--- /dev/null
+++ b/changelog.d/6531.misc
@@ -0,0 +1 @@
+Improve sanity-checking when receiving events over federation.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ebeffbb768..fd3f5ced55 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -397,6 +397,7 @@ class FederationHandler(BaseHandler):
                                 event_map[x.event_id] = x
 
                     state_map = yield resolve_events_with_store(
+                        room_id,
                         room_version,
                         state_maps,
                         event_map,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..0e75e94c6f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from six import iteritems, itervalues
 
@@ -416,6 +416,7 @@ class StateHandler(object):
 
         with Measure(self.clock, "state._resolve_events"):
             new_state = yield resolve_events_with_store(
+                event.room_id,
                 room_version,
                 state_set_ids,
                 event_map=state_map,
@@ -461,7 +462,7 @@ class StateResolutionHandler(object):
         not be called for a single state group
 
         Args:
-            room_id (str): room we are resolving for (used for logging)
+            room_id (str): room we are resolving for (used for logging and sanity checks)
             room_version (str): version of the room
             state_groups_ids (dict[int, dict[(str, str), str]]):
                  map from state group id to the state in that state group
@@ -517,6 +518,7 @@ class StateResolutionHandler(object):
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_store(
+                        room_id,
                         room_version,
                         list(itervalues(state_groups_ids)),
                         event_map=event_map,
@@ -588,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids):
     )
 
 
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+    room_id: str,
+    room_version: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_res_store: "StateResolutionStore",
+):
     """
     Args:
-        room_version(str): Version of the room
+        room_id: the room we are working in
+
+        room_version: Version of the room
 
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
             events will be requested via state_map_factory.
 
-            If None, all events will be fetched via state_map_factory.
+            If None, all events will be fetched via state_res_store.
 
-        state_res_store (StateResolutionStore)
+        state_res_store: a place to fetch events from
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
     v = KNOWN_ROOM_VERSIONS[room_version]
     if v.state_res == StateResolutionVersions.V1:
         return v1.resolve_events_with_store(
-            state_sets, event_map, state_res_store.get_events
+            room_id, state_sets, event_map, state_res_store.get_events
         )
     else:
         return v2.resolve_events_with_store(
-            room_version, state_sets, event_map, state_res_store
+            room_id, room_version, state_sets, event_map, state_res_store
         )
 
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index a2f92d9ff9..b2f9865f39 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,6 +15,7 @@
 
 import hashlib
 import logging
+from typing import Callable, Dict, List, Optional, Tuple
 
 from six import iteritems, iterkeys, itervalues
 
@@ -24,6 +25,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
 
 logger = logging.getLogger(__name__)
 
@@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 
 @defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(
+    room_id: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_map_factory: Callable,
+):
     """
     Args:
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        room_id: the room we are working in
+
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
@@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
 
             If None, all events will be fetched via state_map_factory.
 
-        state_map_factory(func): will be called
+        state_map_factory: will be called
             with a list of event_ids that are needed, and should return with
             a Deferred of dict of event_id to event.
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
@@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     if event_map is not None:
         state_map.update(event_map)
 
+    # everything in the state map should be in the right room
+    for event in state_map.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     #
@@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     )
 
     state_map_new = yield state_map_factory(new_needed_events)
+    for event in state_map_new.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     state_map.update(state_map_new)
 
     return _resolve_with_state(
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index b327c86f40..cb77ed5b78 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,29 +16,40 @@
 import heapq
 import itertools
 import logging
+from typing import Dict, List, Optional, Tuple
 
 from six import iteritems, itervalues
 
 from twisted.internet import defer
 
+import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
+from synapse.events import EventBase
 
 logger = logging.getLogger(__name__)
 
 
 @defer.inlineCallbacks
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+    room_id: str,
+    room_version: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_res_store: "synapse.state.StateResolutionStore",
+):
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
-        room_version (str): The room version
+        room_id: the room we are working in
+
+        room_version: The room version
 
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
@@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
             If None, all events will be fetched via state_res_store.
 
-        state_res_store (StateResolutionStore)
+        state_res_store:
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
@@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     )
     event_map.update(events)
 
+    # everything in the event map should be in the right room
+    for event in event_map.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
 
     logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     )
 
     sorted_power_events = yield _reverse_topological_power_sort(
-        power_events, event_map, state_res_store, full_conflicted_set
+        room_id, power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
 
     # Now sequentially auth each one
     resolved_state = yield _iterative_auth_checks(
+        room_id,
         room_version,
         sorted_power_events,
         unconflicted_state,
@@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
     leftover_events = yield _mainline_sort(
-        leftover_events, pl, event_map, state_res_store
+        room_id, leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
 
     resolved_state = yield _iterative_auth_checks(
-        room_version, leftover_events, resolved_state, event_map, state_res_store
+        room_id,
+        room_version,
+        leftover_events,
+        resolved_state,
+        event_map,
+        state_res_store,
     )
 
     logger.debug("resolved")
@@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
 
 @defer.inlineCallbacks
-def _get_power_level_for_sender(event_id, event_map, state_res_store):
+def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     """Return the power level of the sender of the given event according to
     their auth events.
 
     Args:
+        room_id (str)
         event_id (str)
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
     Returns:
         Deferred[int]
     """
-    event = yield _get_event(event_id, event_map, state_res_store)
+    event = yield _get_event(room_id, event_id, event_map, state_res_store)
 
     pl = None
     for aid in event.auth_event_ids():
-        aev = yield _get_event(aid, event_map, state_res_store)
+        aev = yield _get_event(room_id, aid, event_map, state_res_store)
         if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
             pl = aev
             break
@@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
     if pl is None:
         # Couldn't find power level. Check if they're the creator of the room
         for aid in event.auth_event_ids():
-            aev = yield _get_event(aid, event_map, state_res_store)
+            aev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (aev.type, aev.state_key) == (EventTypes.Create, ""):
                 if aev.content.get("creator") == event.sender:
                     return 100
@@ -279,7 +305,7 @@ def _is_power_event(event):
 
 @defer.inlineCallbacks
 def _add_event_and_auth_chain_to_graph(
-    graph, event_id, event_map, state_res_store, auth_diff
+    graph, room_id, event_id, event_map, state_res_store, auth_diff
 ):
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
@@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph(
     Args:
         graph (dict[str, set[str]]): A map from event ID to the events auth
             event IDs
+        room_id (str): the room we are working in
         event_id (str): Event to add to the graph
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph(
         eid = state.pop()
         graph.setdefault(eid, set())
 
-        event = yield _get_event(eid, event_map, state_res_store)
+        event = yield _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
             if aid in auth_diff:
                 if aid not in graph:
@@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph(
 
 
 @defer.inlineCallbacks
-def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+def _reverse_topological_power_sort(
+    room_id, event_ids, event_map, state_res_store, auth_diff
+):
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
+        room_id (str): the room we are working in
         event_ids (list[str]): The events to sort
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
     graph = {}
     for event_id in event_ids:
         yield _add_event_and_auth_chain_to_graph(
-            graph, event_id, event_map, state_res_store, auth_diff
+            graph, room_id, event_id, event_map, state_res_store, auth_diff
         )
 
     event_to_pl = {}
     for event_id in graph:
-        pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+        pl = yield _get_power_level_for_sender(
+            room_id, event_id, event_map, state_res_store
+        )
         event_to_pl[event_id] = pl
 
     def _get_power_order(event_id):
@@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
 
 @defer.inlineCallbacks
 def _iterative_auth_checks(
-    room_version, event_ids, base_state, event_map, state_res_store
+    room_id, room_version, event_ids, base_state, event_map, state_res_store
 ):
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
     Args:
+        room_id (str)
         room_version (str)
         event_ids (list[str]): Ordered list of events to apply auth checks to
         base_state (dict[tuple[str, str], str]): The set of state to start with
@@ -370,7 +403,7 @@ def _iterative_auth_checks(
 
         auth_events = {}
         for aid in event.auth_event_ids():
-            ev = yield _get_event(aid, event_map, state_res_store)
+            ev = yield _get_event(room_id, aid, event_map, state_res_store)
 
             if ev.rejected_reason is None:
                 auth_events[(ev.type, ev.state_key)] = ev
@@ -378,7 +411,7 @@ def _iterative_auth_checks(
         for key in event_auth.auth_types_for_event(event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
-                ev = yield _get_event(ev_id, event_map, state_res_store)
+                ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
 
                 if ev.rejected_reason is None:
                     auth_events[key] = event_map[ev_id]
@@ -400,11 +433,14 @@ def _iterative_auth_checks(
 
 
 @defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
+def _mainline_sort(
+    room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
+        room_id (str): room we're working in
         event_ids (list[str]): Events to sort
         resolved_power_event_id (str): The final resolved power level event ID
         event_map (dict[str,FrozenEvent])
@@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
     pl = resolved_power_event_id
     while pl:
         mainline.append(pl)
-        pl_ev = yield _get_event(pl, event_map, state_res_store)
+        pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
         auth_events = pl_ev.auth_event_ids()
         pl = None
         for aid in auth_events:
-            ev = yield _get_event(aid, event_map, state_res_store)
+            ev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                 pl = aid
                 break
@@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
         Deferred[int]
     """
 
+    room_id = event.room_id
+
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
     while event:
@@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
         event = None
 
         for aid in auth_events:
-            aev = yield _get_event(aid, event_map, state_res_store)
+            aev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                 event = aev
                 break
@@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
 
 
 @defer.inlineCallbacks
-def _get_event(event_id, event_map, state_res_store):
+def _get_event(room_id, event_id, event_map, state_res_store):
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
 
     Args:
+        room_id (str)
         event_id (str)
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store):
     if event_id not in event_map:
         events = yield state_res_store.get_events([event_id], allow_rejected=True)
         event_map.update(events)
-    return event_map[event_id]
+    event = event_map[event_id]
+    assert event is not None
+    if event.room_id != room_id:
+        raise Exception(
+            "In state res for room %s, event %s is in %s"
+            % (room_id, event_id, event.room_id)
+        )
+    return event
 
 
 def lexicographical_topological_sort(graph, key):
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8d3845c870..0f341d3ac3 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -58,6 +58,7 @@ class FakeEvent(object):
         self.type = type
         self.state_key = state_key
         self.content = content
+        self.room_id = ROOM_ID
 
     def to_event(self, auth_events, prev_events):
         """Given the auth_events and prev_events, convert to a Frozen Event
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
                     event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
             event_map=None,