summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py1666
1 files changed, 898 insertions, 768 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f72b81d419..3e60774b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,17 +19,20 @@
 
 import itertools
 import logging
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
 from six.moves import http_client, zip
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
 
 from twisted.internet import defer
 
+from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.api.errors import (
     AuthError,
@@ -37,14 +40,17 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     FederationError,
+    HttpResponseException,
     RequestSendFailed,
-    StoreError,
     SynapseError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -52,52 +58,49 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.logging.utils import log_function
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationFederationSendEventsRestServlet,
+    ReplicationStoreRoomOnInviteRestServlet,
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
-from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
 
-from ._base import BaseHandler
-
 logger = logging.getLogger(__name__)
 
 
-def shortstr(iterable, maxitems=5):
-    """If iterable has maxitems or fewer, return the stringification of a list
-    containing those items.
+@attr.s
+class _NewEventInfo:
+    """Holds information about a received event, ready for passing to _handle_new_events
 
-    Otherwise, return the stringification of a a list with the first maxitems items,
-    followed by "...".
+    Attributes:
+        event: the received event
 
-    Args:
-        iterable (Iterable): iterable to truncate
-        maxitems (int): number of items to return before truncating
+        state: the state at that event
 
-    Returns:
-        unicode
+        auth_events: the auth_event map for that event
     """
 
-    items = list(itertools.islice(iterable, maxitems + 1))
-    if len(items) <= maxitems:
-        return str(items)
-    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+    event = attr.ib(type=EventBase)
+    state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
         a) handling received Pdus before handing them on as Events to the rest
-        of the home server (including auth and state conflict resoultion)
+        of the homeserver (including auth and state conflict resoultion)
         b) converting events that were produced by local clients that may need
-        to be sent to remote home servers.
+        to be sent to remote homeservers.
         c) doing the necessary dances to invite remote users and join remote
         rooms.
     """
@@ -108,6 +111,8 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -117,13 +122,14 @@ class FederationHandler(BaseHandler):
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
+        self._instance_name = hs.get_instance_name()
+        self._replication = hs.get_replication_data_handler()
 
-        self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
-            hs
-        )
+        self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
         self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
             hs
         )
@@ -131,14 +137,26 @@ class FederationHandler(BaseHandler):
             hs
         )
 
+        if hs.config.worker_app:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
+            self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+                hs
+            )
+        else:
+            self._device_list_updater = hs.get_device_handler().device_list_updater
+            self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -148,17 +166,15 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -179,7 +195,7 @@ class FederationHandler(BaseHandler):
         try:
             self._sanity_check_event(pdu)
         except SynapseError as err:
-            logger.warn(
+            logger.warning(
                 "[%s %s] Received event failed sanity checks", room_id, event_id
             )
             raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
@@ -202,7 +218,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -213,25 +229,24 @@ class FederationHandler(BaseHandler):
             return None
 
         state = None
-        auth_chain = []
 
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
-            if min_depth and pdu.depth < min_depth:
+            if min_depth is not None and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
                 # message, to work around the fact that some events will
                 # reference really really old events we really don't want to
                 # send to the clients.
                 pdu.internal_metadata.outlier = True
-            elif min_depth and pdu.depth > min_depth:
+            elif min_depth is not None and pdu.depth > min_depth:
                 missing_prevs = prevs - seen
                 if sent_to_us_directly and missing_prevs:
                     # If we're missing stuff, ensure we only fetch stuff one
@@ -243,7 +258,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -251,13 +266,19 @@ class FederationHandler(BaseHandler):
                             len(missing_prevs),
                         )
 
-                        yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
-                        )
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            )
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -265,14 +286,6 @@ class FederationHandler(BaseHandler):
                                 room_id,
                                 event_id,
                             )
-                elif missing_prevs:
-                    logger.info(
-                        "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id,
-                        event_id,
-                        len(missing_prevs),
-                        shortstr(missing_prevs),
-                    )
 
             if prevs - seen:
                 # We've still not been able to get all of the prev_events for this event.
@@ -300,7 +313,7 @@ class FederationHandler(BaseHandler):
                 # following.
 
                 if sent_to_us_directly:
-                    logger.warn(
+                    logger.warning(
                         "[%s %s] Rejecting: failed to fetch %d prev events: %s",
                         room_id,
                         event_id,
@@ -317,18 +330,21 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
+                logger.info(
+                    "Event %s is missing prev_events: calculating state for a "
+                    "backwards extremity",
+                    event_id,
+                )
+
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
-                auth_chains = set()
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(
-                        ours.values()
-                    )  # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())  # type: List[StateMap[str]]
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -337,43 +353,17 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "[%s %s] Requesting state at missing prev_event %s",
-                            room_id,
-                            event_id,
-                            p,
+                            "Requesting state at missing prev_event %s", event_id,
                         )
 
-                        room_version = yield self.store.get_room_version(room_id)
-
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            remote_state, got_auth_chain = (
-                                yield self.federation_client.get_state_for_room(
-                                    origin, room_id, p
-                                )
+                            (remote_state, _,) = await self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            # we want the state *after* p; get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
-                            )
-
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
-                            # XXX hrm I'm not convinced that duplicate events will compare
-                            # for equality, so I'm not sure this does what the author
-                            # hoped.
-                            auth_chains.update(got_auth_chain)
-
                             remote_state_map = {
                                 (x.type, x.state_key): x.event_id for x in remote_state
                             }
@@ -382,7 +372,9 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    state_map = yield resolve_events_with_store(
+                    room_version = await self.store.get_room_version_id(room_id)
+                    state_map = await resolve_events_with_store(
+                        room_id,
                         room_version,
                         state_maps,
                         event_map,
@@ -394,17 +386,16 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
                     state = [event_map[e] for e in six.itervalues(state_map)]
-                    auth_chain = list(auth_chains)
                 except Exception:
-                    logger.warn(
+                    logger.warning(
                         "[%s %s] Error attempting to resolve state at missing "
                         "prev_events",
                         room_id,
@@ -418,12 +409,9 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(
-            origin, pdu, state=state, auth_chain=auth_chain
-        )
+        await self._process_received_pdu(origin, pdu, state=state)
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -435,12 +423,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -504,7 +492,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -513,11 +501,13 @@ class FederationHandler(BaseHandler):
                 min_depth=min_depth,
                 timeout=60000,
             )
-        except RequestSendFailed as e:
+        except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
             # We failed to get the missing events, but since we need to handle
             # the case of `get_missing_events` not returning the necessary
             # events anyway, it is safe to simply log the error and continue.
-            logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+            logger.warning(
+                "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
+            )
             return
 
         logger.info(
@@ -541,10 +531,10 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
-                        logger.warn(
+                        logger.warning(
                             "[%s %s] Received prev_event %s failed history check.",
                             room_id,
                             event_id,
@@ -553,66 +543,154 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
-        """ Called when we have a new pdu. We need to do auth checks and put it
-        through the StateHandler.
+    async def _get_state_for_room(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
+
+        Returns:
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
         """
-        room_id = event.room_id
-        event_id = event.event_id
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = await self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
 
-        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+        desired_events = set(state_event_ids + auth_event_ids)
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
+        if include_event_in_state:
+            desired_events.add(event_id)
 
-        seen_ids = yield self.store.have_seen_events(event_ids)
+        event_map = await self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
 
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
+                failed_to_fetch,
+            )
 
-            event_infos = []
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
 
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-                e.internal_metadata.outlier = True
-                auth_ids = e.auth_event_ids()
-                auth = {
-                    (e.type, e.state_key): e
-                    for e in auth_chain
-                    if e.event_id in auth_ids or e.type == EventTypes.Create
-                }
-                event_infos.append({"event": e, "auth_events": auth})
-                seen_ids.add(e.event_id)
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state() and remote_event.rejected_reason is None:
+                remote_state.append(remote_event)
 
-            logger.info(
-                "[%s %s] persisting newly-received auth/state events %s",
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return remote_state, auth_chain
+
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Persists any events we don't already have as outliers.
+
+        If we fail to fetch any of the events, a warning will be logged, and the event
+        will be omitted from the result. Likewise, any events which turn out not to
+        be in the given room.
+
+        Returns:
+            map from event_id to event
+        """
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if missing_events:
+            logger.debug(
+                "Fetching unknown state/auth events %s for room %s",
+                missing_events,
                 room_id,
-                event_id,
-                [e["event"].event_id for e in event_infos],
             )
-            yield self._handle_new_events(origin, event_infos)
+
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, events=missing_events
+            )
+
+            # we need to make sure we re-load from the database to get the rejected
+            # state correct.
+            fetched_events.update(
+                (await self.store.get_events(missing_events, allow_rejected=True))
+            )
+
+        # check for events which were in the wrong room.
+        #
+        # this can happen if a remote server claims that the state or
+        # auth_events at an event in room A are actually events in room B
+
+        bad_events = [
+            (event_id, event.room_id)
+            for event_id, event in fetched_events.items()
+            if event.room_id != room_id
+        ]
+
+        for bad_event_id, bad_room_id in bad_events:
+            # This is a bogus situation, but since we may only discover it a long time
+            # after it happened, we try our best to carry on, by just omitting the
+            # bad events from the returned auth/state set.
+            logger.warning(
+                "Remote server %s claims event %s in room %s is an auth/state "
+                "event in room %s",
+                destination,
+                bad_event_id,
+                bad_room_id,
+                room_id,
+            )
+
+            del fetched_events[bad_event_id]
+
+        return fetched_events
+
+    async def _process_received_pdu(
+        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+    ):
+        """ Called when we have a new pdu. We need to do auth checks and put it
+        through the StateHandler.
+
+        Args:
+            origin: server sending the event
+
+            event: event to be persisted
+
+            state: Normally None, but if we are handling a gap in the graph
+                (ie, we are missing one or more prev_events), the resolved state at the
+                event
+        """
+        room_id = event.room_id
+        event_id = event.event_id
+
+        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        room = yield self.store.get_room(room_id)
-
-        if not room:
-            try:
-                yield self.store.store_room(
-                    room_id=room_id, room_creator_user_id="", is_public=False
-                )
-            except StoreError:
-                logger.exception("Failed to store room.")
-
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 # Only fire user_joined_room if the user has acutally
@@ -620,11 +698,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids()
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -632,11 +710,82 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield self.user_joined_room(user, room_id)
+                    await self.user_joined_room(user, room_id)
+
+        # For encrypted messages we check that we know about the sending device,
+        # if we don't then we mark the device cache for that user as stale.
+        if event.type == EventTypes.Encrypted:
+            device_id = event.content.get("device_id")
+            sender_key = event.content.get("sender_key")
+
+            cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+
+            resync = False  # Whether we should resync device lists.
+
+            device = None
+            if device_id is not None:
+                device = cached_devices.get(device_id)
+                if device is None:
+                    logger.info(
+                        "Received event from remote device not in our cache: %s %s",
+                        event.sender,
+                        device_id,
+                    )
+                    resync = True
+
+            # We also check if the `sender_key` matches what we expect.
+            if sender_key is not None:
+                # Figure out what sender key we're expecting. If we know the
+                # device and recognize the algorithm then we can work out the
+                # exact key to expect. Otherwise check it matches any key we
+                # have for that device.
+                if device:
+                    keys = device.get("keys", {}).get("keys", {})
+
+                    if event.content.get("algorithm") == "m.megolm.v1.aes-sha2":
+                        # For this algorithm we expect a curve25519 key.
+                        key_name = "curve25519:%s" % (device_id,)
+                        current_keys = [keys.get(key_name)]
+                    else:
+                        # We don't know understand the algorithm, so we just
+                        # check it matches a key for the device.
+                        current_keys = keys.values()
+                elif device_id:
+                    # We don't have any keys for the device ID.
+                    current_keys = []
+                else:
+                    # The event didn't include a device ID, so we just look for
+                    # keys across all devices.
+                    current_keys = (
+                        key
+                        for device in cached_devices
+                        for key in device.get("keys", {}).get("keys", {}).values()
+                    )
+
+                # We now check that the sender key matches (one of) the expected
+                # keys.
+                if sender_key not in current_keys:
+                    logger.info(
+                        "Received event from remote device with unexpected sender key: %s %s: %s",
+                        event.sender,
+                        device_id or "<no device_id>",
+                        sender_key,
+                    )
+                    resync = True
+
+            if resync:
+                await self.store.mark_remote_user_device_cache_as_stale(event.sender)
+
+                # Immediately attempt a resync in the background
+                if self.config.worker_app:
+                    return run_in_background(self._user_device_resync, event.sender)
+                else:
+                    return run_in_background(
+                        self._device_list_updater.user_device_resync, event.sender
+                    )
 
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -653,9 +802,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = yield self.store.get_room_version(room_id)
-
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -670,8 +817,8 @@ class FederationHandler(BaseHandler):
         #     self._sanity_check_event(ev)
 
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
-            set(e.event_id for e in events)
+        seen_events = await self.store.have_events_in_timeline(
+            {e.event_id for e in events}
         )
 
         events = [e for e in events if e.event_id not in seen_events]
@@ -681,8 +828,11 @@ class FederationHandler(BaseHandler):
 
         event_map = {e.event_id: e for e in events}
 
-        event_ids = set(e.event_id for e in events)
+        event_ids = {e.event_id for e in events}
 
+        # build a list of events whose prev_events weren't in the batch.
+        # (XXX: this will include events whose prev_events we already have; that doesn't
+        # sound right?)
         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
 
         logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -693,113 +843,32 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
-                destination=dest, room_id=room_id, event_id=e_id
+            state, auth = await self._get_state_for_room(
+                destination=dest,
+                room_id=room_id,
+                event_id=e_id,
+                include_event_in_state=False,
             )
             auth_events.update({a.event_id: a for a in auth})
             auth_events.update({s.event_id: s for s in state})
             state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state
 
-        required_auth = set(
+        required_auth = {
             a_id
             for event in events
             + list(state_events.values())
             + list(auth_events.values())
             for a_id in event.auth_event_ids()
-        )
+        }
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id in event.auth_event_ids()
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch,
-                )
-
-                results = yield make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
-                            run_in_background(
-                                self.federation_client.get_pdu,
-                                [dest],
-                                event_id,
-                                room_version=room_version,
-                                outlier=True,
-                                timeout=10000,
-                            )
-                            for event_id in missing_auth - failed_to_fetch
-                        ],
-                        consumeErrors=True,
-                    )
-                ).addErrback(unwrapFirstError)
-                auth_events.update({a.event_id: a for a in results if a})
-                required_auth.update(
-                    a_id
-                    for event in results
-                    if event
-                    for a_id in event.auth_event_ids()
-                )
-                missing_auth = required_auth - set(auth_events)
-
-                failed_to_fetch = missing_auth - set(auth_events)
-
-        seen_events = yield self.store.have_seen_events(
-            set(auth_events.keys()) | set(state_events.keys())
-        )
-
-        # We now have a chunk of events plus associated state and auth chain to
-        # persist. We do the persistence in two steps:
-        #   1. Auth events and state get persisted as outliers, plus the
-        #      backward extremities get persisted (as non-outliers).
-        #   2. The rest of the events in the chunk get persisted one by one, as
-        #      each one depends on the previous event for its state.
-        #
-        # The important thing is that events in the chunk get persisted as
-        # non-outliers, including when those events are also in the state or
-        # auth chain. Caution must therefore be taken to ensure that they are
-        # not accidentally marked as outliers.
 
-        # Step 1a: persist auth events that *don't* appear in the chunk
         ev_infos = []
-        for a in auth_events.values():
-            # We only want to persist auth events as outliers that we haven't
-            # seen and aren't about to persist as part of the backfilled chunk.
-            if a.event_id in seen_events or a.event_id in event_map:
-                continue
-
-            a.internal_metadata.outlier = True
-            ev_infos.append(
-                {
-                    "event": a,
-                    "auth_events": {
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in a.auth_event_ids()
-                        if a_id in auth_events
-                    },
-                }
-            )
 
-        # Step 1b: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities) as non-outliers.
+        # Step 1: persist the events in the chunk we fetched state for (i.e.
+        # the backwards extremities), with custom auth events and state
         for e_id in events_to_state:
             # For paranoia we ensure that these events are marked as
             # non-outliers
@@ -807,10 +876,10 @@ class FederationHandler(BaseHandler):
             assert not ev.internal_metadata.is_outlier()
 
             ev_infos.append(
-                {
-                    "event": ev,
-                    "state": events_to_state[e_id],
-                    "auth_events": {
+                _NewEventInfo(
+                    event=ev,
+                    state=events_to_state[e_id],
+                    auth_events={
                         (
                             auth_events[a_id].type,
                             auth_events[a_id].state_key,
@@ -818,10 +887,10 @@ class FederationHandler(BaseHandler):
                         for a_id in ev.auth_event_ids()
                         if a_id in auth_events
                     },
-                }
+                )
             )
 
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -837,16 +906,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -878,16 +946,18 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
 
-        extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+        extremities_events = await self.store.get_events(
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
-            self.store,
+        filtered_extremities = await filter_events_for_server(
+            self.storage,
             self.server_name,
             list(extremities_events.values()),
             redact=False,
@@ -916,7 +986,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
 
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -955,12 +1025,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    yield self.backfill(
+                    await self.backfill(
                         dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
@@ -970,6 +1039,12 @@ class FederationHandler(BaseHandler):
                 except SynapseError as e:
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
+                except HttpResponseException as e:
+                    if 400 <= e.code < 500:
+                        raise e.to_synapse_error()
+
+                    logger.info("Failed to backfill from %s because %s", dom, e)
+                    continue
                 except CodeMessageException as e:
                     if 400 <= e.code < 500:
                         raise
@@ -991,7 +1066,7 @@ class FederationHandler(BaseHandler):
 
             return False
 
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
 
@@ -1005,7 +1080,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1015,7 +1090,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1031,7 +1106,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1041,6 +1116,56 @@ class FederationHandler(BaseHandler):
 
         return False
 
+    async def _get_events_and_persist(
+        self, destination: str, room_id: str, events: Iterable[str]
+    ):
+        """Fetch the given events from a server, and persist them as outliers.
+
+        Logs a warning if we can't find the given event.
+        """
+
+        room_version = await self.store.get_room_version(room_id)
+
+        event_infos = []
+
+        async def get_event(event_id: str):
+            with nested_logging_context(event_id):
+                try:
+                    event = await self.federation_client.get_pdu(
+                        [destination], event_id, room_version, outlier=True,
+                    )
+                    if event is None:
+                        logger.warning(
+                            "Server %s didn't return event %s", destination, event_id,
+                        )
+                        return
+
+                    # recursively fetch the auth events for this event
+                    auth_events = await self._get_events_from_store_or_dest(
+                        destination, room_id, event.auth_event_ids()
+                    )
+                    auth = {}
+                    for auth_event_id in event.auth_event_ids():
+                        ae = auth_events.get(auth_event_id)
+                        if ae:
+                            auth[(ae.type, ae.state_key)] = ae
+
+                    event_infos.append(_NewEventInfo(event, None, auth))
+
+                except Exception as e:
+                    logger.warning(
+                        "Error fetching missing state/auth event %s: %s %s",
+                        event_id,
+                        type(e),
+                        e,
+                    )
+
+        await concurrently_execute(get_event, events, 5)
+
+        await self._handle_new_events(
+            destination, event_infos,
+        )
+
     def _sanity_check_event(self, ev):
         """
         Do some early sanity checks of a received event
@@ -1058,7 +1183,7 @@ class FederationHandler(BaseHandler):
             SynapseError if the event does not pass muster
         """
         if len(ev.prev_event_ids()) > 20:
-            logger.warn(
+            logger.warning(
                 "Rejecting event %s which has %i prev_events",
                 ev.event_id,
                 len(ev.prev_event_ids()),
@@ -1066,20 +1191,19 @@ class FederationHandler(BaseHandler):
             raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
 
         if len(ev.auth_event_ids()) > 10:
-            logger.warn(
+            logger.warning(
                 "Rejecting event %s which has %i auth_events",
                 ev.event_id,
                 len(ev.auth_event_ids()),
             )
             raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
 
-    @defer.inlineCallbacks
-    def send_invite(self, target_host, event):
+    async def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = yield self.federation_client.send_invite(
+        pdu = await self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -1088,19 +1212,18 @@ class FederationHandler(BaseHandler):
 
         return pdu
 
-    @defer.inlineCallbacks
-    def on_event_auth(self, event_id):
-        event = yield self.store.get_event(event_id)
-        auth = yield self.store.get_auth_chain(
-            [auth_id for auth_id in event.auth_event_ids()], include_given=True
+    async def on_event_auth(self, event_id: str) -> List[EventBase]:
+        event = await self.store.get_event(event_id)
+        auth = await self.store.get_auth_chain(
+            list(event.auth_event_ids()), include_given=True
         )
-        return [e for e in auth]
+        return list(auth)
 
-    @log_function
-    @defer.inlineCallbacks
-    def do_invite_join(self, target_hosts, room_id, joinee, content):
+    async def do_invite_join(
+        self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
+    ) -> Tuple[str, int]:
         """ Attempts to join the `joinee` to the room `room_id` via the
-        server `target_host`.
+        servers contained in `target_hosts`.
 
         This first triggers a /make_join/ request that returns a partial
         event that we can fill out and sign. This is then sent to the
@@ -1109,10 +1232,23 @@ class FederationHandler(BaseHandler):
 
         We suspend processing of any received events from this room until we
         have finished processing the join.
+
+        Args:
+            target_hosts: List of servers to attempt to join the room with.
+
+            room_id: The ID of the room to join.
+
+            joinee: The User ID of the joining user.
+
+            content: The event content to use for the join event.
         """
+        # TODO: We should be able to call this on workers, but the upgrading of
+        # room stuff after join currently doesn't work on workers.
+        assert self.config.worker.worker_app is None
+
         logger.debug("Joining %s to %s", joinee, room_id)
 
-        origin, event, event_format_version = yield self._make_and_verify_event(
+        origin, event, room_version_obj = await self._make_and_verify_event(
             target_hosts,
             room_id,
             joinee,
@@ -1128,7 +1264,7 @@ class FederationHandler(BaseHandler):
 
         self.room_queues[room_id] = []
 
-        yield self._clean_room_for_join(room_id)
+        await self._clean_room_for_join(room_id)
 
         handled_events = set()
 
@@ -1140,8 +1276,9 @@ class FederationHandler(BaseHandler):
                 target_hosts.insert(0, origin)
             except ValueError:
                 pass
-            ret = yield self.federation_client.send_join(
-                target_hosts, event, event_format_version
+
+            ret = await self.federation_client.send_join(
+                target_hosts, event, room_version_obj
             )
 
             origin = ret["origin"]
@@ -1158,17 +1295,49 @@ class FederationHandler(BaseHandler):
 
             logger.debug("do_invite_join event: %s", event)
 
-            try:
-                yield self.store.store_room(
-                    room_id=room_id, room_creator_user_id="", is_public=False
-                )
-            except Exception:
-                # FIXME
-                pass
+            # if this is the first time we've joined this room, it's time to add
+            # a row to `rooms` with the correct room version. If there's already a
+            # row there, we should override it, since it may have been populated
+            # based on an invite request which lied about the room version.
+            #
+            # federation_client.send_join has already checked that the room
+            # version in the received create event is the same as room_version_obj,
+            # so we can rely on it now.
+            #
+            await self.store.upsert_room_on_join(
+                room_id=room_id, room_version=room_version_obj,
+            )
 
-            yield self._persist_auth_tree(origin, auth_chain, state, event)
+            max_stream_id = await self._persist_auth_tree(
+                origin, auth_chain, state, event, room_version_obj
+            )
+
+            # We wait here until this instance has seen the events come down
+            # replication (if we're using replication) as the below uses caches.
+            #
+            # TODO: Currently the events stream is written to from master
+            await self._replication.wait_for_stream_position(
+                self.config.worker.writers.events, "events", max_stream_id
+            )
+
+            # Check whether this room is the result of an upgrade of a room we already know
+            # about. If so, migrate over user information
+            predecessor = await self.store.get_room_predecessor(room_id)
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
+                return event.event_id, max_stream_id
+            old_room_id = predecessor["room_id"]
+            logger.debug(
+                "Found predecessor for %s during remote join: %s", room_id, old_room_id
+            )
+
+            # We retrieve the room member handler here as to not cause a cyclic dependency
+            member_handler = self.hs.get_room_member_handler()
+            await member_handler.transfer_room_state_on_room_upgrade(
+                old_room_id, room_id
+            )
 
             logger.debug("Finished joining %s to %s", joinee, room_id)
+            return event.event_id, max_stream_id
         finally:
             room_queue = self.room_queues[room_id]
             del self.room_queues[room_id]
@@ -1181,10 +1350,7 @@ class FederationHandler(BaseHandler):
 
             run_in_background(self._handle_queued_pdus, room_queue)
 
-        return True
-
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1200,28 +1366,24 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
-                logger.warn(
+                logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
                 )
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_make_join_request(self, origin, room_id, user_id):
+    async def on_make_join_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> EventBase:
         """ We've received a /make_join/ request, so we create a partial
         join event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
         Args:
-            origin (str): The (verified) server name of the requesting server.
-            room_id (str): Room to create join event in
-            user_id (str): The user to create the join for
-
-        Returns:
-            Deferred[FrozenEvent]
+            origin: The (verified) server name of the requesting server.
+            room_id: Room to create join event in
+            user_id: The user to create the join for
         """
-
         if get_domain_from_id(user_id) != origin:
             logger.info(
                 "Got /make_join request for user %r from different origin %s, ignoring",
@@ -1232,7 +1394,7 @@ class FederationHandler(BaseHandler):
 
         event_content = {"membership": Membership.JOIN}
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
         builder = self.event_builder_factory.new(
             room_version,
@@ -1246,14 +1408,14 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            event, context = yield self.event_creation_handler.create_new_client_event(
+            event, context = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
         except AuthError as e:
-            logger.warn("Failed to create join %r because %s", event, e)
+            logger.warning("Failed to create join to %s because %s", room_id, e)
             raise e
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1264,26 +1426,33 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        yield self.auth.check_from_context(
+        await self.auth.check_from_context(
             room_version, event, context, do_sig_check=False
         )
 
         return event
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_send_join_request(self, origin, pdu):
+    async def on_send_join_request(self, origin, pdu):
         """ We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
         event = pdu
 
         logger.debug(
-            "on_send_join_request: Got event: %s, signatures: %s",
+            "on_send_join_request from %s: Got event: %s, signatures: %s",
+            origin,
             event.event_id,
             event.signatures,
         )
 
+        if get_domain_from_id(event.sender) != origin:
+            logger.info(
+                "Got /send_join request for user %r from different origin %s",
+                event.sender,
+                origin,
+            )
+            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
         event.internal_metadata.outlier = False
         # Send this event on behalf of the origin server.
         #
@@ -1300,9 +1469,9 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context = yield self._handle_new_event(origin, event)
+        context = await self._handle_new_event(origin, event)
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1320,29 +1489,28 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
                 user = UserID.from_string(event.state_key)
-                yield self.user_joined_room(user, event.room_id)
+                await self.user_joined_room(user, event.room_id)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
-        auth_chain = yield self.store.get_auth_chain(state_ids)
+        auth_chain = await self.store.get_auth_chain(state_ids)
 
-        state = yield self.store.get_events(list(prev_state_ids.values()))
+        state = await self.store.get_events(list(prev_state_ids.values()))
 
         return {"state": list(state.values()), "auth_chain": auth_chain}
 
-    @defer.inlineCallbacks
-    def on_invite_request(self, origin, pdu):
+    async def on_invite_request(
+        self, origin: str, event: EventBase, room_version: RoomVersion
+    ):
         """ We've got an invite event. Process and persist it. Sign it.
 
         Respond with the now signed event.
         """
-        event = pdu
-
         if event.state_key is None:
             raise SynapseError(400, "The invite event did not have a state key")
 
-        is_blocked = yield self.store.is_room_blocked(event.room_id)
+        is_blocked = await self.store.is_room_blocked(event.room_id)
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
@@ -1373,24 +1541,35 @@ class FederationHandler(BaseHandler):
         if event.state_key == self._server_notices_mxid:
             raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
 
+        # keep a record of the room version, if we don't yet know it.
+        # (this may get overwritten if we later get a different room version in a
+        # join dance).
+        await self._maybe_store_room_on_invite(
+            room_id=event.room_id, room_version=room_version
+        )
+
         event.internal_metadata.outlier = True
         event.internal_metadata.out_of_band_membership = True
 
         event.signatures.update(
             compute_event_signature(
-                event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0]
+                room_version,
+                event.get_pdu_json(),
+                self.hs.hostname,
+                self.hs.config.signing_key[0],
             )
         )
 
-        context = yield self.state_handler.compute_event_context(event)
-        yield self.persist_events_and_notify([(event, context)])
+        context = await self.state_handler.compute_event_context(event)
+        await self.persist_events_and_notify([(event, context)])
 
         return event
 
-    @defer.inlineCallbacks
-    def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
-        origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave"
+    async def do_remotely_reject_invite(
+        self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
+    ) -> Tuple[EventBase, int]:
+        origin, event, room_version = await self._make_and_verify_event(
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -1405,18 +1584,27 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.federation_client.send_leave(target_hosts, event)
-
-        context = yield self.state_handler.compute_event_context(event)
-        yield self.persist_events_and_notify([(event, context)])
-
-        return event
-
-    @defer.inlineCallbacks
-    def _make_and_verify_event(
-        self, target_hosts, room_id, user_id, membership, content={}, params=None
-    ):
-        origin, event, format_ver = yield self.federation_client.make_membership_event(
+        await self.federation_client.send_leave(target_hosts, event)
+
+        context = await self.state_handler.compute_event_context(event)
+        stream_id = await self.persist_events_and_notify([(event, context)])
+
+        return event, stream_id
+
+    async def _make_and_verify_event(
+        self,
+        target_hosts: Iterable[str],
+        room_id: str,
+        user_id: str,
+        membership: str,
+        content: JsonDict = {},
+        params: Optional[Dict[str, str]] = None,
+    ) -> Tuple[str, EventBase, RoomVersion]:
+        (
+            origin,
+            event,
+            room_version,
+        ) = await self.federation_client.make_membership_event(
             target_hosts, room_id, user_id, membership, content, params=params
         )
 
@@ -1428,22 +1616,19 @@ class FederationHandler(BaseHandler):
         assert event.user_id == user_id
         assert event.state_key == user_id
         assert event.room_id == room_id
-        return origin, event, format_ver
+        return origin, event, room_version
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_make_leave_request(self, origin, room_id, user_id):
+    async def on_make_leave_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> EventBase:
         """ We've received a /make_leave/ request, so we create a partial
         leave event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
         Args:
-            origin (str): The (verified) server name of the requesting server.
-            room_id (str): Room to create leave event in
-            user_id (str): The user to create the leave for
-
-        Returns:
-            Deferred[FrozenEvent]
+            origin: The (verified) server name of the requesting server.
+            room_id: Room to create leave event in
+            user_id: The user to create the leave for
         """
         if get_domain_from_id(user_id) != origin:
             logger.info(
@@ -1453,7 +1638,7 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
         builder = self.event_builder_factory.new(
             room_version,
             {
@@ -1465,11 +1650,11 @@ class FederationHandler(BaseHandler):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1481,18 +1666,16 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            yield self.auth.check_from_context(
+            await self.auth.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
-            logger.warn("Failed to create new leave %r because %s", event, e)
+            logger.warning("Failed to create new leave %r because %s", event, e)
             raise e
 
         return event
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_send_leave_request(self, origin, pdu):
+    async def on_send_leave_request(self, origin, pdu):
         """ We have received a leave event for a room. Fully process it."""
         event = pdu
 
@@ -1502,11 +1685,19 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
+        if get_domain_from_id(event.sender) != origin:
+            logger.info(
+                "Got /send_leave request for user %r from different origin %s",
+                event.sender,
+                origin,
+            )
+            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(origin, event)
+        context = await self._handle_new_event(origin, event)
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1523,16 +1714,15 @@ class FederationHandler(BaseHandler):
 
         return None
 
-    @defer.inlineCallbacks
-    def get_state_for_pdu(self, room_id, event_id):
+    async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups(room_id, [event_id])
+        state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
         if state_groups:
             _, state = list(iteritems(state_groups)).pop()
@@ -1543,7 +1733,7 @@ class FederationHandler(BaseHandler):
                 if "replaces_state" in event.unsigned:
                     prev_id = event.unsigned["replaces_state"]
                     if prev_id != event.event_id:
-                        prev_event = yield self.store.get_event(prev_id)
+                        prev_event = await self.store.get_event(prev_id)
                         results[(event.type, event.state_key)] = prev_event
                 else:
                     del results[(event.type, event.state_key)]
@@ -1553,15 +1743,14 @@ class FederationHandler(BaseHandler):
         else:
             return []
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_pdu(self, room_id, event_id):
+    async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
+        state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
         if state_groups:
             _, state = list(state_groups.items()).pop()
@@ -1580,46 +1769,50 @@ class FederationHandler(BaseHandler):
         else:
             return []
 
-    @defer.inlineCallbacks
     @log_function
-    def on_backfill_request(self, origin, room_id, pdu_list, limit):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+    async def on_backfill_request(
+        self, origin: str, room_id: str, pdu_list: List[str], limit: int
+    ) -> List[EventBase]:
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
+        # Synapse asks for 100 events per backfill request. Do not allow more.
+        limit = min(limit, 100)
 
-        events = yield filter_events_for_server(self.store, origin, events)
+        events = await self.store.get_backfill_events(room_id, pdu_list, limit)
+
+        events = await filter_events_for_server(self.storage, origin, events)
 
         return events
 
-    @defer.inlineCallbacks
     @log_function
-    def get_persisted_pdu(self, origin, event_id):
+    async def get_persisted_pdu(
+        self, origin: str, event_id: str
+    ) -> Optional[EventBase]:
         """Get an event from the database for the given server.
 
         Args:
-            origin [str]: hostname of server which is requesting the event; we
+            origin: hostname of server which is requesting the event; we
                will check that the server is allowed to see it.
-            event_id [str]: id of the event being requested
+            event_id: id of the event being requested
 
         Returns:
-            Deferred[EventBase|None]: None if we know nothing about the event;
-                otherwise the (possibly-redacted) event.
+            None if we know nothing about the event; otherwise the (possibly-redacted) event.
 
         Raises:
             AuthError if the server is not currently in the room
         """
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
         if event:
-            in_room = yield self.auth.check_host_in_room(event.room_id, origin)
+            in_room = await self.auth.check_host_in_room(event.room_id, origin)
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
-            events = yield filter_events_for_server(self.store, origin, [event])
+            events = await filter_events_for_server(self.storage, origin, [event])
             event = events[0]
             return event
         else:
@@ -1628,11 +1821,10 @@ class FederationHandler(BaseHandler):
     def get_min_depth_for_context(self, context):
         return self.store.get_min_depth(context)
 
-    @defer.inlineCallbacks
-    def _handle_new_event(
+    async def _handle_new_event(
         self, origin, event, state=None, auth_events=None, backfilled=False
     ):
-        context = yield self._prep_event(
+        context = await self._prep_event(
             origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
 
@@ -1640,12 +1832,16 @@ class FederationHandler(BaseHandler):
         # hack around with a try/finally instead.
         success = False
         try:
-            if not event.internal_metadata.is_outlier() and not backfilled:
-                yield self.action_generator.handle_push_actions_for_event(
+            if (
+                not event.internal_metadata.is_outlier()
+                and not backfilled
+                and not context.rejected
+            ):
+                await self.action_generator.handle_push_actions_for_event(
                     event, context
                 )
 
-            yield self.persist_events_and_notify(
+            await self.persist_events_and_notify(
                 [(event, context)], backfilled=backfilled
             )
             success = True
@@ -1657,8 +1853,12 @@ class FederationHandler(BaseHandler):
 
         return context
 
-    @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False):
+    async def _handle_new_events(
+        self,
+        origin: str,
+        event_infos: Iterable[_NewEventInfo],
+        backfilled: bool = False,
+    ) -> None:
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
@@ -1667,36 +1867,41 @@ class FederationHandler(BaseHandler):
         Notifies about the events where appropriate.
         """
 
-        @defer.inlineCallbacks
-        def prep(ev_info):
-            event = ev_info["event"]
+        async def prep(ev_info: _NewEventInfo):
+            event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = yield self._prep_event(
+                res = await self._prep_event(
                     origin,
                     event,
-                    state=ev_info.get("state"),
-                    auth_events=ev_info.get("auth_events"),
+                    state=ev_info.state,
+                    auth_events=ev_info.auth_events,
                     backfilled=backfilled,
                 )
             return res
 
-        contexts = yield make_deferred_yieldable(
+        contexts = await make_deferred_yieldable(
             defer.gatherResults(
                 [run_in_background(prep, ev_info) for ev_info in event_infos],
                 consumeErrors=True,
             )
         )
 
-        yield self.persist_events_and_notify(
+        await self.persist_events_and_notify(
             [
-                (ev_info["event"], context)
+                (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
             ],
             backfilled=backfilled,
         )
 
-    @defer.inlineCallbacks
-    def _persist_auth_tree(self, origin, auth_events, state, event):
+    async def _persist_auth_tree(
+        self,
+        origin: str,
+        auth_events: List[EventBase],
+        state: List[EventBase],
+        event: EventBase,
+        room_version: RoomVersion,
+    ) -> int:
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event separately. Notifies about the persisted events
@@ -1705,18 +1910,17 @@ class FederationHandler(BaseHandler):
         Will attempt to fetch missing auth events.
 
         Args:
-            origin (str): Where the events came from
-            auth_events (list)
-            state (list)
-            event (Event)
-
-        Returns:
-            Deferred
+            origin: Where the events came from
+            auth_events
+            state
+            event
+            room_version: The room version we expect this room to have, and
+                will raise if it doesn't match the version in the create event.
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
             e.internal_metadata.outlier = True
-            ctx = yield self.state_handler.compute_event_context(e)
+            ctx = await self.state_handler.compute_event_context(e)
             events_to_context[e.event_id] = ctx
 
         event_map = {
@@ -1734,10 +1938,13 @@ class FederationHandler(BaseHandler):
             # invalid, and it would fail auth checks anyway.
             raise SynapseError(400, "No create event in state")
 
-        room_version = create_event.content.get(
+        room_version_id = create_event.content.get(
             "room_version", RoomVersions.V1.identifier
         )
 
+        if room_version.identifier != room_version_id:
+            raise SynapseError(400, "Room version mismatch")
+
         missing_auth_events = set()
         for e in itertools.chain(auth_events, state, [event]):
             for e_id in e.auth_event_ids():
@@ -1745,8 +1952,8 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
         for e_id in missing_auth_events:
-            m_ev = yield self.federation_client.get_pdu(
-                [origin], e_id, room_version=room_version, outlier=True, timeout=10000
+            m_ev = await self.federation_client.get_pdu(
+                [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
             )
             if m_ev and m_ev.event_id == e_id:
                 event_map[e_id] = m_ev
@@ -1763,7 +1970,7 @@ class FederationHandler(BaseHandler):
                 auth_for_e[(EventTypes.Create, "")] = create_event
 
             try:
-                self.auth.check(room_version, e, auth_events=auth_for_e)
+                event_auth.check(room_version, e, auth_events=auth_for_e)
             except SynapseError as err:
                 # we may get SynapseErrors here as well as AuthErrors. For
                 # instance, there are a couple of (ancient) events in some
@@ -1771,94 +1978,80 @@ class FederationHandler(BaseHandler):
                 # cause SynapseErrors in auth.check. We don't want to give up
                 # the attempt to federate altogether in such cases.
 
-                logger.warn("Rejecting %s because %s", e.event_id, err.msg)
+                logger.warning("Rejecting %s because %s", e.event_id, err.msg)
 
                 if e == event:
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
-        yield self.persist_events_and_notify(
+        await self.persist_events_and_notify(
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
             ]
         )
 
-        new_event_context = yield self.state_handler.compute_event_context(
+        new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        yield self.persist_events_and_notify([(event, new_event_context)])
-
-    @defer.inlineCallbacks
-    def _prep_event(self, origin, event, state, auth_events, backfilled):
-        """
+        return await self.persist_events_and_notify([(event, new_event_context)])
 
-        Args:
-            origin:
-            event:
-            state:
-            auth_events:
-            backfilled (bool)
-
-        Returns:
-            Deferred, which resolves to synapse.events.snapshot.EventContext
-        """
-        context = yield self.state_handler.compute_event_context(event, old_state=state)
+    async def _prep_event(
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
+        auth_events: Optional[StateMap[EventBase]],
+        backfilled: bool,
+    ) -> EventContext:
+        context = await self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
-            prev_state_ids = yield context.get_prev_state_ids(self.store)
-            auth_events_ids = yield self.auth.compute_auth_events(
+            prev_state_ids = await context.get_prev_state_ids()
+            auth_events_ids = await self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = await self.store.get_events(auth_events_ids)
             auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_event_ids():
             if len(event.prev_event_ids()) == 1 and event.depth < 5:
-                c = yield self.store.get_event(
+                c = await self.store.get_event(
                     event.prev_event_ids()[0], allow_none=True
                 )
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
-        try:
-            yield self.do_auth(origin, event, context, auth_events=auth_events)
-        except AuthError as e:
-            logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
-
-            context.rejected = RejectedReason.AUTH_ERROR
+        context = await self.do_auth(origin, event, context, auth_events=auth_events)
 
         if not context.rejected:
-            yield self._check_for_soft_fail(event, state, backfilled)
+            await self._check_for_soft_fail(event, state, backfilled)
 
         if event.type == EventTypes.GuestAccess and not context.rejected:
-            yield self.maybe_kick_guest_users(event)
+            await self.maybe_kick_guest_users(event)
 
         return context
 
-    @defer.inlineCallbacks
-    def _check_for_soft_fail(self, event, state, backfilled):
-        """Checks if we should soft fail the event, if so marks the event as
+    async def _check_for_soft_fail(
+        self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+    ) -> None:
+        """Checks if we should soft fail the event; if so, marks the event as
         such.
 
         Args:
-            event (FrozenEvent)
-            state (dict|None): The state at the event if we don't have all the
-                event's prev events
-            backfilled (bool): Whether the event is from backfill
-
-        Returns:
-            Deferred
+            event
+            state: The state at the event if we don't have all the event's prev events
+            backfilled: Whether the event is from backfill
         """
         # For new (non-backfilled and non-outlier) events we check if the event
         # passes auth based on the current state. If it doesn't then we
         # "soft-fail" the event.
         do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
         if do_soft_fail_check:
-            extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
+            extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
 
             extrem_ids = set(extrem_ids)
             prev_event_ids = set(event.prev_event_ids())
@@ -1869,7 +2062,8 @@ class FederationHandler(BaseHandler):
                 do_soft_fail_check = False
 
         if do_soft_fail_check:
-            room_version = yield self.store.get_room_version(event.room_id)
+            room_version = await self.store.get_room_version_id(event.room_id)
+            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
             # Calculate the "current state".
             if state is not None:
@@ -1885,19 +2079,19 @@ class FederationHandler(BaseHandler):
                 # given state at the event. This should correctly handle cases
                 # like bans, especially with state res v2.
 
-                state_sets = yield self.store.get_state_groups(
+                state_sets = await self.state_store.get_state_groups(
                     event.room_id, extrem_ids
                 )
                 state_sets = list(state_sets.values())
                 state_sets.append(state)
-                current_state_ids = yield self.state_handler.resolve_events(
+                current_state_ids = await self.state_handler.resolve_events(
                     room_version, state_sets, event
                 )
                 current_state_ids = {
                     k: e.event_id for k, e in iteritems(current_state_ids)
                 }
             else:
-                current_state_ids = yield self.state_handler.get_current_state_ids(
+                current_state_ids = await self.state_handler.get_current_state_ids(
                     event.room_id, latest_event_ids=extrem_ids
                 )
 
@@ -1913,26 +2107,27 @@ class FederationHandler(BaseHandler):
                 e for k, e in iteritems(current_state_ids) if k in auth_types
             ]
 
-            current_auth_events = yield self.store.get_events(current_state_ids)
+            current_auth_events = await self.store.get_events(current_state_ids)
             current_auth_events = {
                 (e.type, e.state_key): e for e in current_auth_events.values()
             }
 
             try:
-                self.auth.check(room_version, event, auth_events=current_auth_events)
+                event_auth.check(
+                    room_version_obj, event, auth_events=current_auth_events
+                )
             except AuthError as e:
-                logger.warn("Soft-failing %r because %s", event, e)
+                logger.warning("Soft-failing %r because %s", event, e)
                 event.internal_metadata.soft_failed = True
 
-    @defer.inlineCallbacks
-    def on_query_auth(
+    async def on_query_auth(
         self, origin, event_id, room_id, remote_auth_chain, rejects, missing
     ):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
@@ -1940,70 +2135,77 @@ class FederationHandler(BaseHandler):
         # don't want to fall into the trap of `missing` being wrong.
         for e in remote_auth_chain:
             try:
-                yield self._handle_new_event(origin, e)
+                await self._handle_new_event(origin, e)
             except AuthError:
                 pass
 
         # Now get the current auth_chain for the event.
-        local_auth_chain = yield self.store.get_auth_chain(
-            [auth_id for auth_id in event.auth_event_ids()], include_given=True
+        local_auth_chain = await self.store.get_auth_chain(
+            list(event.auth_event_ids()), include_given=True
         )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
         # everyone.
 
-        ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain)
+        ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
 
         logger.debug("on_query_auth returning: %s", ret)
 
         return ret
 
-    @defer.inlineCallbacks
-    def on_get_missing_events(
+    async def on_get_missing_events(
         self, origin, room_id, earliest_events, latest_events, limit
     ):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
+        # Only allow up to 20 events to be retrieved per request.
         limit = min(limit, 20)
 
-        missing_events = yield self.store.get_missing_events(
+        missing_events = await self.store.get_missing_events(
             room_id=room_id,
             earliest_events=earliest_events,
             latest_events=latest_events,
             limit=limit,
         )
 
-        missing_events = yield filter_events_for_server(
-            self.store, origin, missing_events
+        missing_events = await filter_events_for_server(
+            self.storage, origin, missing_events
         )
 
         return missing_events
 
-    @defer.inlineCallbacks
-    @log_function
-    def do_auth(self, origin, event, context, auth_events):
+    async def do_auth(
+        self,
+        origin: str,
+        event: EventBase,
+        context: EventContext,
+        auth_events: StateMap[EventBase],
+    ) -> EventContext:
         """
 
         Args:
-            origin (str):
-            event (synapse.events.EventBase):
-            context (synapse.events.snapshot.EventContext):
-            auth_events (dict[(str, str)->synapse.events.EventBase]):
+            origin:
+            event:
+            context:
+            auth_events:
                 Map from (event_type, state_key) to event
 
-                What we expect the event's auth_events to be, based on the event's
-                position in the dag. I think? maybe??
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
 
                 Also NB that this function adds entries to it.
         Returns:
-            defer.Deferred[None]
+            updated context object
         """
-        room_version = yield self.store.get_room_version(event.room_id)
+        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         try:
-            yield self._update_auth_events_and_context_for_auth(
+            context = await self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
             )
         except Exception:
@@ -2018,15 +2220,20 @@ class FederationHandler(BaseHandler):
             )
 
         try:
-            self.auth.check(room_version, event, auth_events=auth_events)
+            event_auth.check(room_version_obj, event, auth_events=auth_events)
         except AuthError as e:
-            logger.warn("Failed auth resolution for %r because %s", event, e)
-            raise e
+            logger.warning("Failed auth resolution for %r because %s", event, e)
+            context.rejected = RejectedReason.AUTH_ERROR
 
-    @defer.inlineCallbacks
-    def _update_auth_events_and_context_for_auth(
-        self, origin, event, context, auth_events
-    ):
+        return context
+
+    async def _update_auth_events_and_context_for_auth(
+        self,
+        origin: str,
+        event: EventBase,
+        context: EventContext,
+        auth_events: StateMap[EventBase],
+    ) -> EventContext:
         """Helper for do_auth. See there for docs.
 
         Checks whether a given event has the expected auth events. If it
@@ -2034,59 +2241,59 @@ class FederationHandler(BaseHandler):
         we can come to a consensus (e.g. if one server missed some valid
         state).
 
-        This attempts to resovle any potential divergence of state between
+        This attempts to resolve any potential divergence of state between
         servers, but is not essential and so failures should not block further
         processing of the event.
 
         Args:
-            origin (str):
-            event (synapse.events.EventBase):
-            context (synapse.events.snapshot.EventContext):
-            auth_events (dict[(str, str)->synapse.events.EventBase]):
+            origin:
+            event:
+            context:
+
+            auth_events:
+                Map from (event_type, state_key) to event
+
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
+
+                Also NB that this function adds entries to it.
 
         Returns:
-            defer.Deferred[None]
+            updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
-        if event.is_state():
-            event_key = (event.type, event.state_key)
-        else:
-            event_key = None
-
-        # if the event's auth_events refers to events which are not in our
-        # calculated auth_events, we need to fetch those events from somewhere.
-        #
-        # we start by fetching them from the store, and then try calling /event_auth/.
+        # missing_auth is the set of the event's auth_events which we don't yet have
+        # in auth_events.
         missing_auth = event_auth_events.difference(
             e.event_id for e in auth_events.values()
         )
 
+        # if we have missing events, we need to fetch those events from somewhere.
+        #
+        # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            # TODO: can we use store.have_seen_events here instead?
-            have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
-            logger.debug("Got events %s from store", have_events)
-            missing_auth.difference_update(have_events.keys())
-        else:
-            have_events = {}
-
-        have_events.update({e.event_id: "" for e in auth_events.values()})
+            have_events = await self.store.have_seen_events(missing_auth)
+            logger.debug("Events %s are in the store", have_events)
+            missing_auth.difference_update(have_events)
 
         if missing_auth:
             # If we don't have all the auth events, we need to get them.
             logger.info("auth_events contains unknown events: %s", missing_auth)
             try:
                 try:
-                    remote_auth_chain = yield self.federation_client.get_event_auth(
+                    remote_auth_chain = await self.federation_client.get_event_auth(
                         origin, event.room_id, event.event_id
                     )
                 except RequestSendFailed as e:
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
                     logger.info("Failed to get event auth from remote: %s", e)
-                    return
+                    return context
 
-                seen_remotes = yield self.store.have_seen_events(
+                seen_remotes = await self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
                 )
 
@@ -2109,32 +2316,31 @@ class FederationHandler(BaseHandler):
                         logger.debug(
                             "do_auth %s missing_auth: %s", event.event_id, e.event_id
                         )
-                        yield self._handle_new_event(origin, e, auth_events=auth)
+                        await self._handle_new_event(origin, e, auth_events=auth)
 
                         if e.event_id in event_auth_events:
                             auth_events[(e.type, e.state_key)] = e
                     except AuthError:
                         pass
 
-                have_events = yield self.store.get_seen_events_with_rejections(
-                    event.auth_event_ids()
-                )
             except Exception:
-                # FIXME:
                 logger.exception("Failed to get auth chain")
 
         if event.internal_metadata.is_outlier():
+            # XXX: given that, for an outlier, we'll be working with the
+            # event's *claimed* auth events rather than those we calculated:
+            # (a) is there any point in this test, since different_auth below will
+            # obviously be empty
+            # (b) alternatively, why don't we do it earlier?
             logger.info("Skipping auth_event fetch for outlier")
-            return
+            return context
 
-        # FIXME: Assumes we have and stored all the state for all the
-        # prev_events
         different_auth = event_auth_events.difference(
             e.event_id for e in auth_events.values()
         )
 
         if not different_auth:
-            return
+            return context
 
         logger.info(
             "auth_events refers to events which are not in our calculated auth "
@@ -2142,175 +2348,94 @@ class FederationHandler(BaseHandler):
             different_auth,
         )
 
-        room_version = yield self.store.get_room_version(event.room_id)
+        # XXX: currently this checks for redactions but I'm not convinced that is
+        # necessary?
+        different_events = await self.store.get_events_as_list(different_auth)
 
-        different_events = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.store.get_event, d, allow_none=True, allow_rejected=False
-                    )
-                    for d in different_auth
-                    if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True,
-            )
-        ).addErrback(unwrapFirstError)
-
-        if different_events:
-            local_view = dict(auth_events)
-            remote_view = dict(auth_events)
-            remote_view.update(
-                {(d.type, d.state_key): d for d in different_events if d}
-            )
-
-            new_state = yield self.state_handler.resolve_events(
-                room_version,
-                [list(local_view.values()), list(remote_view.values())],
-                event,
-            )
-
-            logger.info(
-                "After state res: updating auth_events with new state %s",
-                {
-                    (d.type, d.state_key): d.event_id
-                    for d in new_state.values()
-                    if auth_events.get((d.type, d.state_key)) != d
-                },
-            )
-
-            auth_events.update(new_state)
-
-            different_auth = event_auth_events.difference(
-                e.event_id for e in auth_events.values()
-            )
-
-            yield self._update_context_for_auth_events(
-                event, context, auth_events, event_key
-            )
+        for d in different_events:
+            if d.room_id != event.room_id:
+                logger.warning(
+                    "Event %s refers to auth_event %s which is in a different room",
+                    event.event_id,
+                    d.event_id,
+                )
 
-        if not different_auth:
-            # we're done
-            return
+                # don't attempt to resolve the claimed auth events against our own
+                # in this case: just use our own auth events.
+                #
+                # XXX: should we reject the event in this case? It feels like we should,
+                # but then shouldn't we also do so if we've failed to fetch any of the
+                # auth events?
+                return context
+
+        # now we state-resolve between our own idea of the auth events, and the remote's
+        # idea of them.
+
+        local_state = auth_events.values()
+        remote_auth_events = dict(auth_events)
+        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+        remote_state = remote_auth_events.values()
+
+        room_version = await self.store.get_room_version_id(event.room_id)
+        new_state = await self.state_handler.resolve_events(
+            room_version, (local_state, remote_state), event
+        )
 
         logger.info(
-            "auth_events still refers to events which are not in the calculated auth "
-            "chain after state resolution: %s",
-            different_auth,
+            "After state res: updating auth_events with new state %s",
+            {
+                (d.type, d.state_key): d.event_id
+                for d in new_state.values()
+                if auth_events.get((d.type, d.state_key)) != d
+            },
         )
 
-        # Only do auth resolution if we have something new to say.
-        # We can't prove an auth failure.
-        do_resolution = False
-
-        for e_id in different_auth:
-            if e_id in have_events:
-                if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
-                    do_resolution = True
-                    break
-
-        if not do_resolution:
-            logger.info(
-                "Skipping auth resolution due to lack of provable rejection reasons"
-            )
-            return
-
-        logger.info("Doing auth resolution")
-
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
-
-        # 1. Get what we think is the auth chain.
-        auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids)
-        local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True)
-
-        try:
-            # 2. Get remote difference.
-            try:
-                result = yield self.federation_client.query_auth(
-                    origin, event.room_id, event.event_id, local_auth_chain
-                )
-            except RequestSendFailed as e:
-                # The other side isn't around or doesn't implement the
-                # endpoint, so lets just bail out.
-                logger.info("Failed to query auth from remote: %s", e)
-                return
-
-            seen_remotes = yield self.store.have_seen_events(
-                [e.event_id for e in result["auth_chain"]]
-            )
-
-            # 3. Process any remote auth chain events we haven't seen.
-            for ev in result["auth_chain"]:
-                if ev.event_id in seen_remotes:
-                    continue
-
-                if ev.event_id == event.event_id:
-                    continue
-
-                try:
-                    auth_ids = ev.auth_event_ids()
-                    auth = {
-                        (e.type, e.state_key): e
-                        for e in result["auth_chain"]
-                        if e.event_id in auth_ids or event.type == EventTypes.Create
-                    }
-                    ev.internal_metadata.outlier = True
-
-                    logger.debug(
-                        "do_auth %s different_auth: %s", event.event_id, e.event_id
-                    )
-
-                    yield self._handle_new_event(origin, ev, auth_events=auth)
-
-                    if ev.event_id in event_auth_events:
-                        auth_events[(ev.type, ev.state_key)] = ev
-                except AuthError:
-                    pass
-
-        except Exception:
-            # FIXME:
-            logger.exception("Failed to query auth chain")
-
-        # 4. Look at rejects and their proofs.
-        # TODO.
+        auth_events.update(new_state)
 
-        yield self._update_context_for_auth_events(
-            event, context, auth_events, event_key
+        context = await self._update_context_for_auth_events(
+            event, context, auth_events
         )
 
-    @defer.inlineCallbacks
-    def _update_context_for_auth_events(self, event, context, auth_events, event_key):
+        return context
+
+    async def _update_context_for_auth_events(
+        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+    ) -> EventContext:
         """Update the state_ids in an event context after auth event resolution,
         storing the changes as a new state group.
 
         Args:
-            event (Event): The event we're handling the context for
+            event: The event we're handling the context for
 
-            context (synapse.events.snapshot.EventContext): event context
-                to be updated
+            context: initial event context
 
-            auth_events (dict[(str, str)->str]): Events to update in the event
-                context.
+            auth_events: Events to update in the event context.
 
-            event_key ((str, str)): (type, state_key) for the current event.
-                this will not be included in the current_state in the context.
+        Returns:
+            new event context
         """
+        # exclude the state key of the new event from the current_state in the context.
+        if event.is_state():
+            event_key = (event.type, event.state_key)  # type: Optional[Tuple[str, str]]
+        else:
+            event_key = None
         state_updates = {
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
-        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        current_state_ids = await context.get_current_state_ids()
         current_state_ids = dict(current_state_ids)
 
         current_state_ids.update(state_updates)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = await context.get_prev_state_ids()
         prev_state_ids = dict(prev_state_ids)
 
         prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = yield self.store.store_state_group(
+        state_group = await self.state_store.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
@@ -2318,16 +2443,18 @@ class FederationHandler(BaseHandler):
             current_state_ids=current_state_ids,
         )
 
-        yield context.update_state(
+        return EventContext.with_state(
             state_group=state_group,
+            state_group_before_event=context.state_group_before_event,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
             delta_ids=state_updates,
         )
 
-    @defer.inlineCallbacks
-    def construct_auth_difference(self, local_auth, remote_auth):
+    async def construct_auth_difference(
+        self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
+    ) -> Dict:
         """ Given a local and remote auth chain, find the differences. This
         assumes that we have already processed all events in remote_auth
 
@@ -2436,7 +2563,7 @@ class FederationHandler(BaseHandler):
         reason_map = {}
 
         for e in base_remote_rejected:
-            reason = yield self.store.get_rejection_reason(e.event_id)
+            reason = await self.store.get_rejection_reason(e.event_id)
             if reason is None:
                 # TODO: e is not in the current state, so we should
                 # construct some proof of that.
@@ -2444,15 +2571,6 @@ class FederationHandler(BaseHandler):
 
             reason_map[e.event_id] = reason
 
-            if reason == RejectedReason.AUTH_ERROR:
-                pass
-            elif reason == RejectedReason.REPLACED:
-                # TODO: Get proof
-                pass
-            elif reason == RejectedReason.NOT_ANCESTOR:
-                # TODO: Get proof.
-                pass
-
         logger.debug("construct_auth_difference returning")
 
         return {
@@ -2464,9 +2582,8 @@ class FederationHandler(BaseHandler):
             "missing": [e.event_id for e in missing_locals],
         }
 
-    @defer.inlineCallbacks
     @log_function
-    def exchange_third_party_invite(
+    async def exchange_third_party_invite(
         self, sender_user_id, target_user_id, room_id, signed
     ):
         third_party_invite = {"signed": signed}
@@ -2482,16 +2599,16 @@ class FederationHandler(BaseHandler):
             "state_key": target_user_id,
         }
 
-        if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
-            room_version = yield self.store.get_room_version(room_id)
+        if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+            room_version = await self.store.get_room_version_id(room_id)
             builder = self.event_builder_factory.new(room_version, event_dict)
 
             EventValidator().validate_builder(builder)
-            event, context = yield self.event_creation_handler.create_new_client_event(
+            event, context = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
 
-            event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event_allowed = await self.third_party_event_rules.check_event_allowed(
                 event, context
             )
             if not event_allowed:
@@ -2503,58 +2620,58 @@ class FederationHandler(BaseHandler):
                     403, "This event is not allowed in this context", Codes.FORBIDDEN
                 )
 
-            event, context = yield self.add_display_name_to_third_party_invite(
+            event, context = await self.add_display_name_to_third_party_invite(
                 room_version, event_dict, event, context
             )
 
-            EventValidator().validate_new(event)
+            EventValidator().validate_new(event, self.config)
 
             # We need to tell the transaction queue to send this out, even
             # though the sender isn't a local user.
             event.internal_metadata.send_on_behalf_of = self.hs.hostname
 
             try:
-                yield self.auth.check_from_context(room_version, event, context)
+                await self.auth.check_from_context(room_version, event, context)
             except AuthError as e:
-                logger.warn("Denying new third party invite %r because %s", event, e)
+                logger.warning("Denying new third party invite %r because %s", event, e)
                 raise e
 
-            yield self._check_signature(event, context)
+            await self._check_signature(event, context)
+
+            # We retrieve the room member handler here as to not cause a cyclic dependency
             member_handler = self.hs.get_room_member_handler()
-            yield member_handler.send_membership_event(None, event, context)
+            await member_handler.send_membership_event(None, event, context)
         else:
-            destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
-            yield self.federation_client.forward_third_party_invite(
+            destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
+            await self.federation_client.forward_third_party_invite(
                 destinations, room_id, event_dict
             )
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_exchange_third_party_invite_request(self, room_id, event_dict):
+    async def on_exchange_third_party_invite_request(
+        self, room_id: str, event_dict: JsonDict
+    ) -> None:
         """Handle an exchange_third_party_invite request from a remote server
 
         The remote server will call this when it wants to turn a 3pid invite
         into a normal m.room.member invite.
 
         Args:
-            room_id (str): The ID of the room.
+            room_id: The ID of the room.
 
             event_dict (dict[str, Any]): Dictionary containing the event body.
 
-        Returns:
-            Deferred: resolves (to None)
         """
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
         # NB: event_dict has a particular specced format we might need to fudge
         # if we change event formats too much.
         builder = self.event_builder_factory.new(room_version, event_dict)
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -2565,26 +2682,26 @@ class FederationHandler(BaseHandler):
                 403, "This event is not allowed in this context", Codes.FORBIDDEN
             )
 
-        event, context = yield self.add_display_name_to_third_party_invite(
+        event, context = await self.add_display_name_to_third_party_invite(
             room_version, event_dict, event, context
         )
 
         try:
-            self.auth.check_from_context(room_version, event, context)
+            await self.auth.check_from_context(room_version, event, context)
         except AuthError as e:
-            logger.warn("Denying third party invite %r because %s", event, e)
+            logger.warning("Denying third party invite %r because %s", event, e)
             raise e
-        yield self._check_signature(event, context)
+        await self._check_signature(event, context)
 
         # We need to tell the transaction queue to send this out, even
         # though the sender isn't a local user.
         event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
         member_handler = self.hs.get_room_member_handler()
-        yield member_handler.send_membership_event(None, event, context)
+        await member_handler.send_membership_event(None, event, context)
 
-    @defer.inlineCallbacks
-    def add_display_name_to_third_party_invite(
+    async def add_display_name_to_third_party_invite(
         self, room_version, event_dict, event, context
     ):
         key = (
@@ -2592,14 +2709,19 @@ class FederationHandler(BaseHandler):
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = await context.get_prev_state_ids()
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
-            original_invite = yield self.store.get_event(
+            original_invite = await self.store.get_event(
                 original_invite_id, allow_none=True
             )
         if original_invite:
-            display_name = original_invite.content["display_name"]
+            # If the m.room.third_party_invite event's content is empty, it means the
+            # invite has been revoked. In this case, we don't have to raise an error here
+            # because the auth check will fail on the invite (because it's not able to
+            # fetch public keys from the m.room.third_party_invite event's content, which
+            # is empty).
+            display_name = original_invite.content.get("display_name")
             event_dict["content"]["third_party_invite"]["display_name"] = display_name
         else:
             logger.info(
@@ -2611,14 +2733,13 @@ class FederationHandler(BaseHandler):
 
         builder = self.event_builder_factory.new(room_version, event_dict)
         EventValidator().validate_builder(builder)
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
-        EventValidator().validate_new(event)
+        EventValidator().validate_new(event, self.config)
         return (event, context)
 
-    @defer.inlineCallbacks
-    def _check_signature(self, event, context):
+    async def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
 
@@ -2635,12 +2756,12 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = await context.get_prev_state_ids()
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
         if invite_event_id:
-            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+            invite_event = await self.store.get_event(invite_event_id, allow_none=True)
 
         if not invite_event:
             raise AuthError(403, "Could not find invite")
@@ -2689,7 +2810,7 @@ class FederationHandler(BaseHandler):
                             raise
                         try:
                             if "key_validity_url" in public_key_object:
-                                yield self._check_key_revocation(
+                                await self._check_key_revocation(
                                     public_key, public_key_object["key_validity_url"]
                                 )
                         except Exception:
@@ -2703,8 +2824,7 @@ class FederationHandler(BaseHandler):
                 last_exception = e
         raise last_exception
 
-    @defer.inlineCallbacks
-    def _check_key_revocation(self, public_key, url):
+    async def _check_key_revocation(self, public_key, url):
         """
         Checks whether public_key has been revoked.
 
@@ -2718,47 +2838,58 @@ class FederationHandler(BaseHandler):
                 for revocation.
         """
         try:
-            response = yield self.http_client.get_json(url, {"public_key": public_key})
+            response = await self.http_client.get_json(url, {"public_key": public_key})
         except Exception:
             raise SynapseError(502, "Third party certificate could not be checked")
         if "valid" not in response or not response["valid"]:
             raise AuthError(403, "Third party certificate was invalid")
 
-    @defer.inlineCallbacks
-    def persist_events_and_notify(self, event_and_contexts, backfilled=False):
+    async def persist_events_and_notify(
+        self,
+        event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+        backfilled: bool = False,
+    ) -> int:
         """Persists events and tells the notifier/pushers about them, if
         necessary.
 
         Args:
-            event_and_contexts(list[tuple[FrozenEvent, EventContext]])
-            backfilled (bool): Whether these events are a result of
+            event_and_contexts:
+            backfilled: Whether these events are a result of
                 backfilling or not
-
-        Returns:
-            Deferred
         """
-        if self.config.worker_app:
-            yield self._send_events_to_master(
+        if self.config.worker.writers.events != self._instance_name:
+            result = await self._send_events(
+                instance_name=self.config.worker.writers.events,
                 store=self.store,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
+            return result["max_stream_id"]
         else:
-            max_stream_id = yield self.store.persist_events(
+            max_stream_id = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
+            if self._ephemeral_messages_enabled:
+                for (event, context) in event_and_contexts:
+                    # If there's an expiry timestamp on the event, schedule its expiry.
+                    self._message_handler.maybe_schedule_expiry(event)
+
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    yield self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_id)
+
+            return max_stream_id
 
-    def _notify_persisted_event(self, event, max_stream_id):
+    async def _notify_persisted_event(
+        self, event: EventBase, max_stream_id: int
+    ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
 
         Args:
-            event (FrozenEvent)
-            max_stream_id (int): The max_stream_id returned by persist_events
+            event:
+            max_stream_id: The max_stream_id returned by persist_events
         """
 
         extra_users = []
@@ -2782,32 +2913,31 @@ class FederationHandler(BaseHandler):
             event, event_stream_id, max_stream_id, extra_users=extra_users
         )
 
-        return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
 
-    def _clean_room_for_join(self, room_id):
+    async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
 
         Args:
-            room_id (str)
+            room_id
         """
         if self.config.worker_app:
-            return self._clean_room_for_join_client(room_id)
+            await self._clean_room_for_join_client(room_id)
         else:
-            return self.store.clean_room_for_join(room_id)
+            await self.store.clean_room_for_join(room_id)
 
-    def user_joined_room(self, user, room_id):
+    async def user_joined_room(self, user: UserID, room_id: str) -> None:
         """Called when a new user has joined the room
         """
         if self.config.worker_app:
-            return self._notify_user_membership_change(
+            await self._notify_user_membership_change(
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            user_joined_room(self.distributor, user, room_id)
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, remote_room_hosts, room_id):
+    async def get_room_complexity(self, remote_room_hosts, room_id):
         """
         Fetch the complexity of a remote room over federation.
 
@@ -2821,12 +2951,12 @@ class FederationHandler(BaseHandler):
         """
 
         for host in remote_room_hosts:
-            res = yield self.federation_client.get_room_complexity(host, room_id)
+            res = await self.federation_client.get_room_complexity(host, room_id)
 
             # We got a result, return it.
             if res:
-                defer.returnValue(res)
+                return res
 
         # We fell off the bottom, couldn't get the complexity from anyone. Oh
         # well.
-        defer.returnValue(None)
+        return None