summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- changelog.d/6501.misc 1
-rw-r--r-- changelog.d/6503.misc 1
-rw-r--r-- changelog.d/6521.misc 1
-rw-r--r-- changelog.d/6524.misc 2
-rw-r--r-- changelog.d/6526.bugfix 1
-rw-r--r-- changelog.d/6527.bugfix 1
-rw-r--r-- changelog.d/6530.misc 2
-rw-r--r-- changelog.d/6531.misc 1
-rw-r--r-- synapse/event_auth.py 12
-rw-r--r-- synapse/federation/federation_client.py 103
-rw-r--r-- synapse/handlers/federation.py 352
-rw-r--r-- synapse/state/__init__.py 32
-rw-r--r-- synapse/state/v1.py 34
-rw-r--r-- synapse/state/v2.py 100
-rw-r--r-- synapse/util/async_helpers.py 4
-rw-r--r-- tests/state/test_v2.py 3
16 files changed, 356 insertions, 294 deletions
diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc
new file mode 100644
index 0000000000..255f45a9c3
--- /dev/null
+++ b/changelog.d/6501.misc
@@ -0,0 +1 @@
+Refactor get_events_from_store_or_dest to return a dict.
diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc
new file mode 100644
index 0000000000..e4e9a5a3d4
--- /dev/null
+++ b/changelog.d/6503.misc
@@ -0,0 +1 @@
+Move get_state methods into FederationHandler.
diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc
new file mode 100644
index 0000000000..d9a44389b9
--- /dev/null
+++ b/changelog.d/6521.misc
@@ -0,0 +1 @@
+Refactor some code in the event authentication path for clarity.
diff --git a/changelog.d/6524.misc b/changelog.d/6524.misc
new file mode 100644
index 0000000000..f885597426
--- /dev/null
+++ b/changelog.d/6524.misc
@@ -0,0 +1,2 @@
+Improve sanity-checking when receiving events over federation.
+
diff --git a/changelog.d/6526.bugfix b/changelog.d/6526.bugfix
new file mode 100644
index 0000000000..53214b0748
--- /dev/null
+++ b/changelog.d/6526.bugfix
@@ -0,0 +1 @@
+Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.
\ No newline at end of file
diff --git a/changelog.d/6527.bugfix b/changelog.d/6527.bugfix
new file mode 100644
index 0000000000..53214b0748
--- /dev/null
+++ b/changelog.d/6527.bugfix
@@ -0,0 +1 @@
+Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.
\ No newline at end of file
diff --git a/changelog.d/6530.misc b/changelog.d/6530.misc
new file mode 100644
index 0000000000..f885597426
--- /dev/null
+++ b/changelog.d/6530.misc
@@ -0,0 +1,2 @@
+Improve sanity-checking when receiving events over federation.
+
diff --git a/changelog.d/6531.misc b/changelog.d/6531.misc
new file mode 100644
index 0000000000..598efb79fc
--- /dev/null
+++ b/changelog.d/6531.misc
@@ -0,0 +1 @@
+Improve sanity-checking when receiving events over federation.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ec3243b27b..d184b0273b 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -48,6 +48,18 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
     if not hasattr(event, "room_id"):
         raise AuthError(500, "Event has no room_id: %s" % event)
 
+    room_id = event.room_id
+
+    # I'm not really expecting to get auth events in the wrong room, but let's
+    # sanity-check it
+    for auth_event in auth_events.values():
+        if auth_event.room_id != room_id:
+            raise Exception(
+                "During auth for event %s in room %s, found event %s in the state "
+                "which is in room %s"
+                % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+            )
+
     if do_sig_check:
         sender_domain = get_domain_from_id(event.sender)
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 709449c9e3..d396e6564f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,8 +18,6 @@ import copy
 import itertools
 import logging
 
-from six.moves import range
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import builder, room_version_to_event_format
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.utils import log_function
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
         return signed_pdu
 
     @defer.inlineCallbacks
-    @log_function
-    def get_state_for_room(self, destination, room_id, event_id):
-        """Requests all of the room state at a given event from a remote homeserver.
-
-        Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+    def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
+        """Calls the /state_ids endpoint to fetch the state at a particular point
+        in the room, and the auth events for the given event
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            Tuple[List[str], List[str]]:  a tuple of (state event_ids, auth event_ids)
         """
         result = yield self.transport_layer.get_room_state_ids(
             destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
-        fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
-            destination, room_id, set(state_event_ids + auth_event_ids)
-        )
-
-        if failed_to_fetch:
-            logger.warning(
-                "Failed to fetch missing state/auth events for %s: %s",
-                room_id,
-                failed_to_fetch,
-            )
-
-        event_map = {ev.event_id: ev for ev in fetched_events}
-
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return pdus, auth_chain
-
-    @defer.inlineCallbacks
-    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Args:
-            destination (str)
-            room_id (str)
-            event_ids (list)
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-        signed_events = list(seen_events.values())
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            return signed_events, failed_to_fetch
-
-        logger.debug(
-            "Fetching unknown state/auth events %s for room %s",
-            missing_events,
-            event_ids,
-        )
-
-        room_version = yield self.store.get_room_version(room_id)
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i : i + batch_size])
-
-            deferreds = [
-                run_in_background(
-                    self.get_pdu,
-                    destinations=[destination],
-                    event_id=e_id,
-                    room_version=room_version,
-                )
-                for e_id in batch
-            ]
-
-            res = yield make_deferred_yieldable(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
+        if not isinstance(state_event_ids, list) or not isinstance(
+            auth_event_ids, list
+        ):
+            raise Exception("invalid response from /state_ids")
 
-        return signed_events, failed_to_fetch
+        return state_event_ids, auth_event_ids
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..abe02907b9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,8 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -240,7 +239,6 @@ class FederationHandler(BaseHandler):
             return None
 
         state = None
-        auth_chain = []
 
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
@@ -346,7 +344,6 @@ class FederationHandler(BaseHandler):
 
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
-                auth_chains = set()
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
@@ -370,38 +367,14 @@ class FederationHandler(BaseHandler):
                             p,
                         )
 
-                        room_version = yield self.store.get_room_version(room_id)
-
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            (
-                                remote_state,
-                                got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
-
-                            # we want the state *after* p; get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
+                            (remote_state, _,) = yield self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
-                            # XXX hrm I'm not convinced that duplicate events will compare
-                            # for equality, so I'm not sure this does what the author
-                            # hoped.
-                            auth_chains.update(got_auth_chain)
-
                             remote_state_map = {
                                 (x.type, x.state_key): x.event_id for x in remote_state
                             }
@@ -410,7 +383,9 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
+                    room_version = yield self.store.get_room_version(room_id)
                     state_map = yield resolve_events_with_store(
+                        room_id,
                         room_version,
                         state_maps,
                         event_map,
@@ -430,7 +405,6 @@ class FederationHandler(BaseHandler):
                     event_map.update(evs)
 
                     state = [event_map[e] for e in six.itervalues(state_map)]
-                    auth_chain = list(auth_chains)
                 except Exception:
                     logger.warning(
                         "[%s %s] Error attempting to resolve state at missing "
@@ -446,9 +420,7 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(
-            origin, pdu, state=state, auth_chain=auth_chain
-        )
+        yield self._process_received_pdu(origin, pdu, state=state)
 
     @defer.inlineCallbacks
     def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
@@ -584,49 +556,149 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
-        """ Called when we have a new pdu. We need to do auth checks and put it
-        through the StateHandler.
+    @log_function
+    def _get_state_for_room(
+        self, destination, room_id, event_id, include_event_in_state
+    ):
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            room_id (str): The id of the room we're interested in.
+            event_id (str): The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
+
+        Returns:
+            Deferred[Tuple[List[EventBase], List[EventBase]]]:
+                A list of events in the state, and a list of events in the auth chain
+                for the given event.
         """
-        room_id = event.room_id
-        event_id = event.event_id
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = yield self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
 
-        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+        desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
+        event_map = yield self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s: %s",
+                room_id,
+                failed_to_fetch,
+            )
 
-        seen_ids = yield self.store.have_seen_events(event_ids)
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
 
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state() and remote_event.rejected_reason is None:
+                remote_state.append(remote_event)
 
-            event_infos = []
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        auth_chain.sort(key=lambda e: e.depth)
 
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-                e.internal_metadata.outlier = True
-                auth_ids = e.auth_event_ids()
-                auth = {
-                    (e.type, e.state_key): e
-                    for e in auth_chain
-                    if e.event_id in auth_ids or e.type == EventTypes.Create
-                }
-                event_infos.append(_NewEventInfo(event=e, auth_events=auth))
-                seen_ids.add(e.event_id)
+        return remote_state, auth_chain
 
-            logger.info(
-                "[%s %s] persisting newly-received auth/state events %s",
+    @defer.inlineCallbacks
+    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination (str)
+            room_id (str)
+            event_ids (Iterable[str])
+
+        Persists any events we don't already have as outliers.
+
+        If we fail to fetch any of the events, a warning will be logged, and the event
+        will be omitted from the result. Likewise, any events which turn out not to
+        be in the given room.
+
+        Returns:
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
+        """
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if missing_events:
+            logger.debug(
+                "Fetching unknown state/auth events %s for room %s",
+                missing_events,
                 room_id,
-                event_id,
-                [e.event.event_id for e in event_infos],
             )
-            yield self._handle_new_events(origin, event_infos)
+
+            yield self._get_events_and_persist(
+                destination=destination, room_id=room_id, events=missing_events
+            )
+
+            # we need to make sure we re-load from the database to get the rejected
+            # state correct.
+            fetched_events.update(
+                (yield self.store.get_events(missing_events, allow_rejected=True))
+            )
+
+        # check for events which were in the wrong room.
+        #
+        # this can happen if a remote server claims that the state or
+        # auth_events at an event in room A are actually events in room B
+
+        bad_events = list(
+            (event_id, event.room_id)
+            for event_id, event in fetched_events.items()
+            if event.room_id != room_id
+        )
+
+        for bad_event_id, bad_room_id in bad_events:
+            # This is a bogus situation, but since we may only discover it a long time
+            # after it happened, we try our best to carry on, by just omitting the
+            # bad events from the returned auth/state set.
+            logger.warning(
+                "Remote server %s claims event %s in room %s is an auth/state "
+                "event in room %s",
+                destination,
+                bad_event_id,
+                bad_room_id,
+                room_id,
+            )
+            del fetched_events[bad_event_id]
+
+        return fetched_events
+
+    @defer.inlineCallbacks
+    def _process_received_pdu(self, origin, event, state):
+        """ Called when we have a new pdu. We need to do auth checks and put it
+        through the StateHandler.
+
+        Args:
+            origin: server sending the event
+
+            event: event to be persisted
+
+            state: Normally None, but if we are handling a gap in the graph
+                (ie, we are missing one or more prev_events), the resolved state at the
+                event
+        """
+        room_id = event.room_id
+        event_id = event.event_id
+
+        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
             context = yield self._handle_new_event(origin, event, state=state)
@@ -683,8 +755,6 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = yield self.store.get_room_version(room_id)
-
         events = yield self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
@@ -713,6 +783,9 @@ class FederationHandler(BaseHandler):
 
         event_ids = set(e.event_id for e in events)
 
+        # build a list of events whose prev_events weren't in the batch.
+        # (XXX: this will include events whose prev_events we already have; that doesn't
+        # sound right?)
         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
 
         logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -723,7 +796,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = yield self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -741,95 +814,11 @@ class FederationHandler(BaseHandler):
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id in event.auth_event_ids()
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch,
-                )
-
-                results = yield make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
-                            run_in_background(
-                                self.federation_client.get_pdu,
-                                [dest],
-                                event_id,
-                                room_version=room_version,
-                                outlier=True,
-                                timeout=10000,
-                            )
-                            for event_id in missing_auth - failed_to_fetch
-                        ],
-                        consumeErrors=True,
-                    )
-                ).addErrback(unwrapFirstError)
-                auth_events.update({a.event_id: a for a in results if a})
-                required_auth.update(
-                    a_id
-                    for event in results
-                    if event
-                    for a_id in event.auth_event_ids()
-                )
-                missing_auth = required_auth - set(auth_events)
-
-                failed_to_fetch = missing_auth - set(auth_events)
-
-        seen_events = yield self.store.have_seen_events(
-            set(auth_events.keys()) | set(state_events.keys())
-        )
-
-        # We now have a chunk of events plus associated state and auth chain to
-        # persist. We do the persistence in two steps:
-        #   1. Auth events and state get persisted as outliers, plus the
-        #      backward extremities get persisted (as non-outliers).
-        #   2. The rest of the events in the chunk get persisted one by one, as
-        #      each one depends on the previous event for its state.
-        #
-        # The important thing is that events in the chunk get persisted as
-        # non-outliers, including when those events are also in the state or
-        # auth chain. Caution must therefore be taken to ensure that they are
-        # not accidentally marked as outliers.
 
-        # Step 1a: persist auth events that *don't* appear in the chunk
         ev_infos = []
-        for a in auth_events.values():
-            # We only want to persist auth events as outliers that we haven't
-            # seen and aren't about to persist as part of the backfilled chunk.
-            if a.event_id in seen_events or a.event_id in event_map:
-                continue
 
-            a.internal_metadata.outlier = True
-            ev_infos.append(
-                _NewEventInfo(
-                    event=a,
-                    auth_events={
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in a.auth_event_ids()
-                        if a_id in auth_events
-                    },
-                )
-            )
-
-        # Step 1b: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities) as non-outliers.
+        # Step 1: persist the events in the chunk we fetched state for (i.e.
+        # the backwards extremities), with custom auth events and state
         for e_id in events_to_state:
             # For paranoia we ensure that these events are marked as
             # non-outliers
@@ -1071,6 +1060,57 @@ class FederationHandler(BaseHandler):
 
         return False
 
+    @defer.inlineCallbacks
+    def _get_events_and_persist(
+        self, destination: str, room_id: str, events: Iterable[str]
+    ):
+        """Fetch the given events from a server, and persist them as outliers.
+
+        Logs a warning if we can't find the given event.
+        """
+
+        room_version = yield self.store.get_room_version(room_id)
+
+        event_infos = []
+
+        async def get_event(event_id: str):
+            with nested_logging_context(event_id):
+                try:
+                    event = await self.federation_client.get_pdu(
+                        [destination], event_id, room_version, outlier=True,
+                    )
+                    if event is None:
+                        logger.warning(
+                            "Server %s didn't return event %s", destination, event_id,
+                        )
+                        return
+
+                    # recursively fetch the auth events for this event
+                    auth_events = await self._get_events_from_store_or_dest(
+                        destination, room_id, event.auth_event_ids()
+                    )
+                    auth = {}
+                    for auth_event_id in event.auth_event_ids():
+                        ae = auth_events.get(auth_event_id)
+                        if ae:
+                            auth[(ae.type, ae.state_key)] = ae
+
+                    event_infos.append(_NewEventInfo(event, None, auth))
+
+                except Exception as e:
+                    logger.warning(
+                        "Error fetching missing state/auth event %s: %s %s",
+                        event_id,
+                        type(e),
+                        e,
+                    )
+
+        yield concurrently_execute(get_event, events, 5)
+
+        yield self._handle_new_events(
+            destination, event_infos,
+        )
+
     def _sanity_check_event(self, ev):
         """
         Do some early sanity checks of a received event
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 139beef8ed..0e75e94c6f 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from six import iteritems, itervalues
 
@@ -416,6 +416,7 @@ class StateHandler(object):
 
         with Measure(self.clock, "state._resolve_events"):
             new_state = yield resolve_events_with_store(
+                event.room_id,
                 room_version,
                 state_set_ids,
                 event_map=state_map,
@@ -461,7 +462,7 @@ class StateResolutionHandler(object):
         not be called for a single state group
 
         Args:
-            room_id (str): room we are resolving for (used for logging)
+            room_id (str): room we are resolving for (used for logging and sanity checks)
             room_version (str): version of the room
             state_groups_ids (dict[int, dict[(str, str), str]]):
                  map from state group id to the state in that state group
@@ -517,6 +518,7 @@ class StateResolutionHandler(object):
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_store(
+                        room_id,
                         room_version,
                         list(itervalues(state_groups_ids)),
                         event_map=event_map,
@@ -588,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids):
     )
 
 
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+    room_id: str,
+    room_version: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_res_store: "StateResolutionStore",
+):
     """
     Args:
-        room_version(str): Version of the room
+        room_id: the room we are working in
+
+        room_version: Version of the room
 
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
             events will be requested via state_map_factory.
 
-            If None, all events will be fetched via state_map_factory.
+            If None, all events will be fetched via state_res_store.
 
-        state_res_store (StateResolutionStore)
+        state_res_store: a place to fetch events from
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
     v = KNOWN_ROOM_VERSIONS[room_version]
     if v.state_res == StateResolutionVersions.V1:
         return v1.resolve_events_with_store(
-            state_sets, event_map, state_res_store.get_events
+            room_id, state_sets, event_map, state_res_store.get_events
         )
     else:
         return v2.resolve_events_with_store(
-            room_version, state_sets, event_map, state_res_store
+            room_id, room_version, state_sets, event_map, state_res_store
         )
 
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index a2f92d9ff9..b2f9865f39 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,6 +15,7 @@
 
 import hashlib
 import logging
+from typing import Callable, Dict, List, Optional, Tuple
 
 from six import iteritems, iterkeys, itervalues
 
@@ -24,6 +25,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
 
 logger = logging.getLogger(__name__)
 
@@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 
 @defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(
+    room_id: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_map_factory: Callable,
+):
     """
     Args:
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        room_id: the room we are working in
+
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
@@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
 
             If None, all events will be fetched via state_map_factory.
 
-        state_map_factory(func): will be called
+        state_map_factory: will be called
             with a list of event_ids that are needed, and should return with
             a Deferred of dict of event_id to event.
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
@@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     if event_map is not None:
         state_map.update(event_map)
 
+    # everything in the state map should be in the right room
+    for event in state_map.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     #
@@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
     )
 
     state_map_new = yield state_map_factory(new_needed_events)
+    for event in state_map_new.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     state_map.update(state_map_new)
 
     return _resolve_with_state(
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index b327c86f40..cb77ed5b78 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,29 +16,40 @@
 import heapq
 import itertools
 import logging
+from typing import Dict, List, Optional, Tuple
 
 from six import iteritems, itervalues
 
 from twisted.internet import defer
 
+import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
+from synapse.events import EventBase
 
 logger = logging.getLogger(__name__)
 
 
 @defer.inlineCallbacks
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+    room_id: str,
+    room_version: str,
+    state_sets: List[Dict[Tuple[str, str], str]],
+    event_map: Optional[Dict[str, EventBase]],
+    state_res_store: "synapse.state.StateResolutionStore",
+):
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
-        room_version (str): The room version
+        room_id: the room we are working in
+
+        room_version: The room version
 
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
+        state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
 
-        event_map(dict[str,FrozenEvent]|None):
+        event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
             used as a starting point fof finding the state we need; any missing
@@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
             If None, all events will be fetched via state_res_store.
 
-        state_res_store (StateResolutionStore)
+        state_res_store: a place to fetch events from
 
-    Returns
+    Returns:
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
@@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     )
     event_map.update(events)
 
+    # everything in the event map should be in the right room
+    for event in event_map.values():
+        if event.room_id != room_id:
+            raise Exception(
+                "Attempting to state-resolve for room %s with event %s which is in %s"
+                % (room_id, event.event_id, event.room_id,)
+            )
+
     full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
 
     logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
     )
 
     sorted_power_events = yield _reverse_topological_power_sort(
-        power_events, event_map, state_res_store, full_conflicted_set
+        room_id, power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
 
     # Now sequentially auth each one
     resolved_state = yield _iterative_auth_checks(
+        room_id,
         room_version,
         sorted_power_events,
         unconflicted_state,
@@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
     leftover_events = yield _mainline_sort(
-        leftover_events, pl, event_map, state_res_store
+        room_id, leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
 
     resolved_state = yield _iterative_auth_checks(
-        room_version, leftover_events, resolved_state, event_map, state_res_store
+        room_id,
+        room_version,
+        leftover_events,
+        resolved_state,
+        event_map,
+        state_res_store,
     )
 
     logger.debug("resolved")
@@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
 
 
 @defer.inlineCallbacks
-def _get_power_level_for_sender(event_id, event_map, state_res_store):
+def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     """Return the power level of the sender of the given event according to
     their auth events.
 
     Args:
+        room_id (str): the room we are working in
         event_id (str)
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
     Returns:
         Deferred[int]
     """
-    event = yield _get_event(event_id, event_map, state_res_store)
+    event = yield _get_event(room_id, event_id, event_map, state_res_store)
 
     pl = None
     for aid in event.auth_event_ids():
-        aev = yield _get_event(aid, event_map, state_res_store)
+        aev = yield _get_event(room_id, aid, event_map, state_res_store)
         if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
             pl = aev
             break
@@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
     if pl is None:
         # Couldn't find power level. Check if they're the creator of the room
         for aid in event.auth_event_ids():
-            aev = yield _get_event(aid, event_map, state_res_store)
+            aev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (aev.type, aev.state_key) == (EventTypes.Create, ""):
                 if aev.content.get("creator") == event.sender:
                     return 100
@@ -279,7 +305,7 @@ def _is_power_event(event):
 
 @defer.inlineCallbacks
 def _add_event_and_auth_chain_to_graph(
-    graph, event_id, event_map, state_res_store, auth_diff
+    graph, room_id, event_id, event_map, state_res_store, auth_diff
 ):
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
@@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph(
     Args:
         graph (dict[str, set[str]]): A map from event ID to the events auth
             event IDs
+        room_id (str): the room we are working in
         event_id (str): Event to add to the graph
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph(
         eid = state.pop()
         graph.setdefault(eid, set())
 
-        event = yield _get_event(eid, event_map, state_res_store)
+        event = yield _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
             if aid in auth_diff:
                 if aid not in graph:
@@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph(
 
 
 @defer.inlineCallbacks
-def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+def _reverse_topological_power_sort(
+    room_id, event_ids, event_map, state_res_store, auth_diff
+):
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
+        room_id (str): the room we are working in
         event_ids (list[str]): The events to sort
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
     graph = {}
     for event_id in event_ids:
         yield _add_event_and_auth_chain_to_graph(
-            graph, event_id, event_map, state_res_store, auth_diff
+            graph, room_id, event_id, event_map, state_res_store, auth_diff
         )
 
     event_to_pl = {}
     for event_id in graph:
-        pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+        pl = yield _get_power_level_for_sender(
+            room_id, event_id, event_map, state_res_store
+        )
         event_to_pl[event_id] = pl
 
     def _get_power_order(event_id):
@@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
 
 @defer.inlineCallbacks
 def _iterative_auth_checks(
-    room_version, event_ids, base_state, event_map, state_res_store
+    room_id, room_version, event_ids, base_state, event_map, state_res_store
 ):
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
     Args:
+        room_id (str): the room we are working in
         room_version (str)
         event_ids (list[str]): Ordered list of events to apply auth checks to
         base_state (dict[tuple[str, str], str]): The set of state to start with
@@ -370,7 +403,7 @@ def _iterative_auth_checks(
 
         auth_events = {}
         for aid in event.auth_event_ids():
-            ev = yield _get_event(aid, event_map, state_res_store)
+            ev = yield _get_event(room_id, aid, event_map, state_res_store)
 
             if ev.rejected_reason is None:
                 auth_events[(ev.type, ev.state_key)] = ev
@@ -378,7 +411,7 @@ def _iterative_auth_checks(
         for key in event_auth.auth_types_for_event(event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
-                ev = yield _get_event(ev_id, event_map, state_res_store)
+                ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
 
                 if ev.rejected_reason is None:
                     auth_events[key] = event_map[ev_id]
@@ -400,11 +433,14 @@ def _iterative_auth_checks(
 
 
 @defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
+def _mainline_sort(
+    room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
+        room_id (str): room we're working in
         event_ids (list[str]): Events to sort
         resolved_power_event_id (str): The final resolved power level event ID
         event_map (dict[str,FrozenEvent])
@@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
     pl = resolved_power_event_id
     while pl:
         mainline.append(pl)
-        pl_ev = yield _get_event(pl, event_map, state_res_store)
+        pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
         auth_events = pl_ev.auth_event_ids()
         pl = None
         for aid in auth_events:
-            ev = yield _get_event(aid, event_map, state_res_store)
+            ev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                 pl = aid
                 break
@@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
         Deferred[int]
     """
 
+    room_id = event.room_id
+
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
     while event:
@@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
         event = None
 
         for aid in auth_events:
-            aev = yield _get_event(aid, event_map, state_res_store)
+            aev = yield _get_event(room_id, aid, event_map, state_res_store)
             if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                 event = aev
                 break
@@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
 
 
 @defer.inlineCallbacks
-def _get_event(event_id, event_map, state_res_store):
+def _get_event(room_id, event_id, event_map, state_res_store):
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
 
     Args:
+        room_id (str): the room the event is expected to be in
         event_id (str)
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
@@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store):
     if event_id not in event_map:
         events = yield state_res_store.get_events([event_id], allow_rejected=True)
         event_map.update(events)
-    return event_map[event_id]
+    event = event_map[event_id]
+    assert event is not None
+    if event.room_id != room_id:
+        raise Exception(
+            "In state res for room %s, event %s is in %s"
+            % (room_id, event_id, event.room_id)
+        )
+    return event
 
 
 def lexicographical_topological_sort(graph, key):
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5c4de2e69f..04b6abdc24 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):
 
     Args:
         func (func): Function to execute, should return a deferred or coroutine.
-        args (list): List of arguments to pass to func, each invocation of func
-            gets a signle argument.
+        args (Iterable): The arguments to pass to func; each invocation of func
+            gets a single argument.
         limit (int): Maximum number of conccurent executions.
 
     Returns:
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8d3845c870..0f341d3ac3 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -58,6 +58,7 @@ class FakeEvent(object):
         self.type = type
         self.state_key = state_key
         self.content = content
+        self.room_id = ROOM_ID
 
     def to_event(self, auth_events, prev_events):
         """Given the auth_events and prev_events, convert to a Frozen Event
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
                     event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
             event_map=None,