summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py562
1 files changed, 282 insertions, 280 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index da06ab379d..32a8a2ee46 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
 
 import six
 from six import iteritems
@@ -36,21 +38,24 @@ from synapse.api.errors import (
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.crypto.event_signing import compute_event_signature
-from synapse.events import room_version_to_event_format
+from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+    make_deferred_yieldable,
+    nested_logging_context,
+    run_in_background,
+)
 from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
 from synapse.logging.utils import log_function
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
     ReplicationGetQueryRestServlet,
 )
-from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import glob_to_regex, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -75,6 +80,9 @@ class FederationServer(FederationBase):
 
         self.auth = hs.get_auth()
         self.handler = hs.get_handlers().federation_handler
+        self.state = hs.get_state_handler()
+
+        self.device_handler = hs.get_device_handler()
 
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")
@@ -87,14 +95,14 @@ class FederationServer(FederationBase):
         # come in waves.
         self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_backfill_request(self, origin, room_id, versions, limit):
-        with (yield self._server_linearizer.queue((origin, room_id))):
+    async def on_backfill_request(
+        self, origin: str, room_id: str, versions: List[str], limit: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
-            yield self.check_server_matches_acl(origin_host, room_id)
+            await self.check_server_matches_acl(origin_host, room_id)
 
-            pdus = yield self.handler.on_backfill_request(
+            pdus = await self.handler.on_backfill_request(
                 origin, room_id, versions, limit
             )
 
@@ -102,76 +110,114 @@ class FederationServer(FederationBase):
 
         return 200, res
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_incoming_transaction(self, origin, transaction_data):
+    async def on_incoming_transaction(
+        self, origin: str, transaction_data: JsonDict
+    ) -> Tuple[int, Dict[str, Any]]:
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
 
-        if not transaction.transaction_id:
+        if not transaction.transaction_id:  # type: ignore
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
         with (
-            yield self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)
+            await self._transaction_linearizer.queue(
+                (origin, transaction.transaction_id)  # type: ignore
             )
         ):
-            result = yield self._handle_incoming_transaction(
+            result = await self._handle_incoming_transaction(
                 origin, transaction, request_time
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _handle_incoming_transaction(self, origin, transaction, request_time):
+    async def _handle_incoming_transaction(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
         """ Process an incoming transaction and return the HTTP response
 
         Args:
-            origin (unicode): the server making the request
-            transaction (Transaction): incoming transaction
-            request_time (int): timestamp that the HTTP request arrived at
+            origin: the server making the request
+            transaction: incoming transaction
+            request_time: timestamp that the HTTP request arrived at
 
         Returns:
-            Deferred[(int, object)]: http response code and body
+            HTTP response code and body
         """
-        response = yield self.transaction_actions.have_responded(origin, transaction)
+        response = await self.transaction_actions.have_responded(origin, transaction)
 
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id,
+                transaction.transaction_id,  # type: ignore
             )
             return response
 
-        logger.debug("[%s] Transaction is new", transaction.transaction_id)
+        logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore
 
-        # Reject if PDU count > 50 and EDU count > 100
-        if len(transaction.pdus) > 50 or (
-            hasattr(transaction, "edus") and len(transaction.edus) > 100
+        # Reject if PDU count > 50 or EDU count > 100
+        if len(transaction.pdus) > 50 or (  # type: ignore
+            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
         ):
 
             logger.info("Transaction PDU or EDU count too large. Returning 400")
 
             response = {}
-            yield self.transaction_actions.set_response(
+            await self.transaction_actions.set_response(
                 origin, transaction, 400, response
             )
             return 400, response
 
-        received_pdus_counter.inc(len(transaction.pdus))
+        # We process PDUs and EDUs in parallel. This is important as we don't
+        # want to block things like to device messages from reaching clients
+        # behind the potentially expensive handling of PDUs.
+        pdu_results, _ = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self._handle_pdus_in_txn, origin, transaction, request_time
+                    ),
+                    run_in_background(self._handle_edus_in_txn, origin, transaction),
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
+
+        response = {"pdus": pdu_results}
+
+        logger.debug("Returning: %s", str(response))
+
+        await self.transaction_actions.set_response(origin, transaction, 200, response)
+        return 200, response
+
+    async def _handle_pdus_in_txn(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Dict[str, dict]:
+        """Process the PDUs in a received transaction.
+
+        Args:
+            origin: the server making the request
+            transaction: incoming transaction
+            request_time: timestamp that the HTTP request arrived at
+
+        Returns:
+            A map from event ID of a processed PDU to any errors we should
+            report back to the sending server.
+        """
+
+        received_pdus_counter.inc(len(transaction.pdus))  # type: ignore
 
         origin_host, _ = parse_server_name(origin)
 
-        pdus_by_room = {}
+        pdus_by_room = {}  # type: Dict[str, List[EventBase]]
 
-        for p in transaction.pdus:
+        for p in transaction.pdus:  # type: ignore
             if "unsigned" in p:
                 unsigned = p["unsigned"]
                 if "age" in unsigned:
@@ -196,24 +242,17 @@ class FederationServer(FederationBase):
                 continue
 
             try:
-                room_version = yield self.store.get_room_version(room_id)
+                room_version = await self.store.get_room_version(room_id)
             except NotFoundError:
                 logger.info("Ignoring PDU for unknown room_id: %s", room_id)
                 continue
-
-            try:
-                format_ver = room_version_to_event_format(room_version)
-            except UnsupportedRoomVersionError:
+            except UnsupportedRoomVersionError as e:
                 # this can happen if support for a given room version is withdrawn,
                 # so that we still get events for said room.
-                logger.info(
-                    "Ignoring PDU for room %s with unknown version %s",
-                    room_id,
-                    room_version,
-                )
+                logger.info("Ignoring PDU: %s", e)
                 continue
 
-            event = event_from_pdu_json(p, format_ver)
+            event = event_from_pdu_json(p, room_version)
             pdus_by_room.setdefault(room_id, []).append(event)
 
         pdu_results = {}
@@ -222,13 +261,12 @@ class FederationServer(FederationBase):
         # require callouts to other servers to fetch missing events), but
         # impose a limit to avoid going too crazy with ram/cpu.
 
-        @defer.inlineCallbacks
-        def process_pdus_for_room(room_id):
+        async def process_pdus_for_room(room_id: str):
             logger.debug("Processing PDUs for %s", room_id)
             try:
-                yield self.check_server_matches_acl(origin_host, room_id)
+                await self.check_server_matches_acl(origin_host, room_id)
             except AuthError as e:
-                logger.warn("Ignoring PDUs for room %s from banned server", room_id)
+                logger.warning("Ignoring PDUs for room %s from banned server", room_id)
                 for pdu in pdus_by_room[room_id]:
                     event_id = pdu.event_id
                     pdu_results[event_id] = e.error_dict()
@@ -238,10 +276,10 @@ class FederationServer(FederationBase):
                 event_id = pdu.event_id
                 with nested_logging_context(event_id):
                     try:
-                        yield self._handle_received_pdu(origin, pdu)
+                        await self._handle_received_pdu(origin, pdu)
                         pdu_results[event_id] = {}
                     except FederationError as e:
-                        logger.warn("Error handling PDU %s: %s", event_id, e)
+                        logger.warning("Error handling PDU %s: %s", event_id, e)
                         pdu_results[event_id] = {"error": str(e)}
                     except Exception as e:
                         f = failure.Failure()
@@ -252,36 +290,40 @@ class FederationServer(FederationBase):
                             exc_info=(f.type, f.value, f.getTracebackObject()),
                         )
 
-        yield concurrently_execute(
+        await concurrently_execute(
             process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
-        if hasattr(transaction, "edus"):
-            for edu in (Edu(**x) for x in transaction.edus):
-                yield self.received_edu(origin, edu.edu_type, edu.content)
-
-        response = {"pdus": pdu_results}
+        return pdu_results
 
-        logger.debug("Returning: %s", str(response))
+    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+        """Process the EDUs in a received transaction.
+        """
 
-        yield self.transaction_actions.set_response(origin, transaction, 200, response)
-        return 200, response
+        async def _process_edu(edu_dict):
+            received_edus_counter.inc()
 
-    @defer.inlineCallbacks
-    def received_edu(self, origin, edu_type, content):
-        received_edus_counter.inc()
-        yield self.registry.on_edu(edu_type, origin, content)
+            edu = Edu(
+                origin=origin,
+                destination=self.server_name,
+                edu_type=edu_dict["edu_type"],
+                content=edu_dict["content"],
+            )
+            await self.registry.on_edu(edu.edu_type, origin, edu.content)
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_context_state_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
+        await concurrently_execute(
+            _process_edu,
+            getattr(transaction, "edus", []),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
+    async def on_context_state_request(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, room_id)
+        await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -290,237 +332,196 @@ class FederationServer(FederationBase):
         # in the cache so we could return it without waiting for the linearizer
         # - but that's non-trivial to get right, and anyway somewhat defeats
         # the point of the linearizer.
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            resp = yield self._state_resp_cache.wrap(
-                (room_id, event_id),
-                self._on_context_state_request_compute,
-                room_id,
-                event_id,
+        with (await self._server_linearizer.queue((origin, room_id))):
+            resp = dict(
+                await self._state_resp_cache.wrap(
+                    (room_id, event_id),
+                    self._on_context_state_request_compute,
+                    room_id,
+                    event_id,
+                )
             )
 
+        room_version = await self.store.get_room_version_id(room_id)
+        resp["room_version"] = room_version
+
         return 200, resp
 
-    @defer.inlineCallbacks
-    def on_state_ids_request(self, origin, room_id, event_id):
+    async def on_state_ids_request(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
         if not event_id:
             raise NotImplementedError("Specify an event")
 
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, room_id)
+        await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
-        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+        state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+        auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
 
         return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
 
-    @defer.inlineCallbacks
-    def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
-        auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
-
-        for event in auth_chain:
-            # We sign these again because there was a bug where we
-            # incorrectly signed things the first time round
-            if self.hs.is_mine_id(event.event_id):
-                event.signatures.update(
-                    compute_event_signature(
-                        event.get_pdu_json(),
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0],
-                    )
-                )
+    async def _on_context_state_request_compute(
+        self, room_id: str, event_id: str
+    ) -> Dict[str, list]:
+        if event_id:
+            pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+        else:
+            pdus = (await self.state.get_current_state(room_id)).values()
+
+        auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
 
         return {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
         }
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_pdu_request(self, origin, event_id):
-        pdu = yield self.handler.get_persisted_pdu(origin, event_id)
+    async def on_pdu_request(
+        self, origin: str, event_id: str
+    ) -> Tuple[int, Union[JsonDict, str]]:
+        pdu = await self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
             return 200, self._transaction_from_pdus([pdu]).get_dict()
         else:
             return 404, ""
 
-    @defer.inlineCallbacks
-    def on_query_request(self, query_type, args):
+    async def on_query_request(
+        self, query_type: str, args: Dict[str, str]
+    ) -> Tuple[int, Dict[str, Any]]:
         received_queries_counter.labels(query_type).inc()
-        resp = yield self.registry.on_query(query_type, args)
+        resp = await self.registry.on_query(query_type, args)
         return 200, resp
 
-    @defer.inlineCallbacks
-    def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+    async def on_make_join_request(
+        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+    ) -> Dict[str, Any]:
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, room_id)
+        await self.check_server_matches_acl(origin_host, room_id)
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
         if room_version not in supported_versions:
-            logger.warn("Room version %s not in %s", room_version, supported_versions)
+            logger.warning(
+                "Room version %s not in %s", room_version, supported_versions
+            )
             raise IncompatibleRoomVersionError(room_version=room_version)
 
-        pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
+        pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    @defer.inlineCallbacks
-    def on_invite_request(self, origin, content, room_version):
-        if room_version not in KNOWN_ROOM_VERSIONS:
+    async def on_invite_request(
+        self, origin: str, content: JsonDict, room_version_id: str
+    ) -> Dict[str, Any]:
+        room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+        if not room_version:
             raise SynapseError(
                 400,
                 "Homeserver does not support this room version",
                 Codes.UNSUPPORTED_ROOM_VERSION,
             )
 
-        format_ver = room_version_to_event_format(room_version)
-
-        pdu = event_from_pdu_json(content, format_ver)
+        pdu = event_from_pdu_json(content, room_version)
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, pdu.room_id)
-        ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+        await self.check_server_matches_acl(origin_host, pdu.room_id)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
         time_now = self._clock.time_msec()
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
-    @defer.inlineCallbacks
-    def on_send_join_request(self, origin, content, room_id):
+    async def on_send_join_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> Dict[str, Any]:
         logger.debug("on_send_join_request: content: %s", content)
 
-        room_version = yield self.store.get_room_version(room_id)
-        format_ver = room_version_to_event_format(room_version)
-        pdu = event_from_pdu_json(content, format_ver)
+        room_version = await self.store.get_room_version(room_id)
+        pdu = event_from_pdu_json(content, room_version)
 
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+        await self.check_server_matches_acl(origin_host, pdu.room_id)
 
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
-        res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+        res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        return (
-            200,
-            {
-                "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-                "auth_chain": [
-                    p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-                ],
-            },
-        )
+        return {
+            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+        }
 
-    @defer.inlineCallbacks
-    def on_make_leave_request(self, origin, room_id, user_id):
+    async def on_make_leave_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> Dict[str, Any]:
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, room_id)
-        pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
+        await self.check_server_matches_acl(origin_host, room_id)
+        pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    @defer.inlineCallbacks
-    def on_send_leave_request(self, origin, content, room_id):
+    async def on_send_leave_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
 
-        room_version = yield self.store.get_room_version(room_id)
-        format_ver = room_version_to_event_format(room_version)
-        pdu = event_from_pdu_json(content, format_ver)
+        room_version = await self.store.get_room_version(room_id)
+        pdu = event_from_pdu_json(content, room_version)
 
         origin_host, _ = parse_server_name(origin)
-        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+        await self.check_server_matches_acl(origin_host, pdu.room_id)
 
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-        yield self.handler.on_send_leave_request(origin, pdu)
-        return 200, {}
 
-    @defer.inlineCallbacks
-    def on_event_auth(self, origin, room_id, event_id):
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            origin_host, _ = parse_server_name(origin)
-            yield self.check_server_matches_acl(origin_host, room_id)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
-            time_now = self._clock.time_msec()
-            auth_pdus = yield self.handler.on_event_auth(event_id)
-            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
-        return 200, res
+        await self.handler.on_send_leave_request(origin, pdu)
+        return {}
 
-    @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, room_id, event_id):
-        """
-        Content is a dict with keys::
-            auth_chain (list): A list of events that give the auth chain.
-            missing (list): A list of event_ids indicating what the other
-              side (`origin`) think we're missing.
-            rejects (dict): A mapping from event_id to a 2-tuple of reason
-              string and a proof (or None) of why the event was rejected.
-              The keys of this dict give the list of events the `origin` has
-              rejected.
-
-        Args:
-            origin (str)
-            content (dict)
-            event_id (str)
-
-        Returns:
-            Deferred: Results in `dict` with the same format as `content`
-        """
-        with (yield self._server_linearizer.queue((origin, room_id))):
+    async def on_event_auth(
+        self, origin: str, room_id: str, event_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
+        with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
-            yield self.check_server_matches_acl(origin_host, room_id)
-
-            room_version = yield self.store.get_room_version(room_id)
-            format_ver = room_version_to_event_format(room_version)
-
-            auth_chain = [
-                event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
-            ]
-
-            signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                origin, auth_chain, outlier=True, room_version=room_version
-            )
-
-            ret = yield self.handler.on_query_auth(
-                origin,
-                event_id,
-                room_id,
-                signed_auth,
-                content.get("rejects", []),
-                content.get("missing", []),
-            )
+            await self.check_server_matches_acl(origin_host, room_id)
 
             time_now = self._clock.time_msec()
-            send_content = {
-                "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
-                "rejects": ret.get("rejects", []),
-                "missing": ret.get("missing", []),
-            }
-
-        return 200, send_content
+            auth_pdus = await self.handler.on_event_auth(event_id)
+            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
+        return 200, res
 
     @log_function
-    def on_query_client_keys(self, origin, content):
-        return self.on_query_request("client_keys", content)
+    async def on_query_client_keys(
+        self, origin: str, content: Dict[str, str]
+    ) -> Tuple[int, Dict[str, Any]]:
+        return await self.on_query_request("client_keys", content)
 
-    def on_query_user_devices(self, origin, user_id):
-        return self.on_query_request("user_devices", user_id)
+    async def on_query_user_devices(
+        self, origin: str, user_id: str
+    ) -> Tuple[int, Dict[str, Any]]:
+        keys = await self.device_handler.on_federation_query_user_devices(user_id)
+        return 200, keys
 
     @trace
-    @defer.inlineCallbacks
-    @log_function
-    def on_claim_client_keys(self, origin, content):
+    async def on_claim_client_keys(
+        self, origin: str, content: JsonDict
+    ) -> Dict[str, Any]:
         query = []
         for user_id, device_keys in content.get("one_time_keys", {}).items():
             for device_id, algorithm in device_keys.items():
                 query.append((user_id, device_id, algorithm))
 
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
-        results = yield self.store.claim_e2e_one_time_keys(query)
+        results = await self.store.claim_e2e_one_time_keys(query)
 
-        json_result = {}
+        json_result = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
@@ -542,16 +543,19 @@ class FederationServer(FederationBase):
 
         return {"one_time_keys": json_result}
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_get_missing_events(
-        self, origin, room_id, earliest_events, latest_events, limit
-    ):
-        with (yield self._server_linearizer.queue((origin, room_id))):
+    async def on_get_missing_events(
+        self,
+        origin: str,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> Dict[str, list]:
+        with (await self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
-            yield self.check_server_matches_acl(origin_host, room_id)
+            await self.check_server_matches_acl(origin_host, room_id)
 
-            logger.info(
+            logger.debug(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d",
                 earliest_events,
@@ -559,27 +563,27 @@ class FederationServer(FederationBase):
                 limit,
             )
 
-            missing_events = yield self.handler.on_get_missing_events(
+            missing_events = await self.handler.on_get_missing_events(
                 origin, room_id, earliest_events, latest_events, limit
             )
 
             if len(missing_events) < 5:
-                logger.info(
+                logger.debug(
                     "Returning %d events: %r", len(missing_events), missing_events
                 )
             else:
-                logger.info("Returning %d events", len(missing_events))
+                logger.debug("Returning %d events", len(missing_events))
 
             time_now = self._clock.time_msec()
 
         return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
 
     @log_function
-    def on_openid_userinfo(self, token):
+    async def on_openid_userinfo(self, token: str) -> Optional[str]:
         ts_now_ms = self._clock.time_msec()
-        return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
+        return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
 
-    def _transaction_from_pdus(self, pdu_list):
+    def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
         """Returns a new Transaction containing the given PDUs suitable for
         transmission.
         """
@@ -592,8 +596,7 @@ class FederationServer(FederationBase):
             destination=None,
         )
 
-    @defer.inlineCallbacks
-    def _handle_received_pdu(self, origin, pdu):
+    async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
         """ Process a PDU received in a federation /send/ transaction.
 
         If the event is invalid, then this method throws a FederationError.
@@ -614,10 +617,8 @@ class FederationServer(FederationBase):
         until we try to backfill across the discontinuity.
 
         Args:
-            origin (str): server which sent the pdu
-            pdu (FrozenEvent): received pdu
-
-        Returns (Deferred): completes with None
+            origin: server which sent the pdu
+            pdu: received pdu
 
         Raises: FederationError if the signatures / hash do not match, or
             if the event was unacceptable for any other reason (eg, too large,
@@ -646,68 +647,67 @@ class FederationServer(FederationBase):
                 logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
 
         # We've already checked that we know the room version by this point
-        room_version = yield self.store.get_room_version(pdu.room_id)
+        room_version = await self.store.get_room_version(pdu.room_id)
 
         # Check signature.
         try:
-            pdu = yield self._check_sigs_and_hash(room_version, pdu)
+            pdu = await self._check_sigs_and_hash(room_version, pdu)
         except SynapseError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
 
-        yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
+        await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-    @defer.inlineCallbacks
-    def exchange_third_party_invite(
-        self, sender_user_id, target_user_id, room_id, signed
+    async def exchange_third_party_invite(
+        self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
     ):
-        ret = yield self.handler.exchange_third_party_invite(
+        ret = await self.handler.exchange_third_party_invite(
             sender_user_id, target_user_id, room_id, signed
         )
         return ret
 
-    @defer.inlineCallbacks
-    def on_exchange_third_party_invite_request(self, room_id, event_dict):
-        ret = yield self.handler.on_exchange_third_party_invite_request(
+    async def on_exchange_third_party_invite_request(
+        self, room_id: str, event_dict: Dict
+    ):
+        ret = await self.handler.on_exchange_third_party_invite_request(
             room_id, event_dict
         )
         return ret
 
-    @defer.inlineCallbacks
-    def check_server_matches_acl(self, server_name, room_id):
+    async def check_server_matches_acl(self, server_name: str, room_id: str):
         """Check if the given server is allowed by the server ACLs in the room
 
         Args:
-            server_name (str): name of server, *without any port part*
-            room_id (str): ID of the room to check
+            server_name: name of server, *without any port part*
+            room_id: ID of the room to check
 
         Raises:
             AuthError if the server does not match the ACL
         """
-        state_ids = yield self.store.get_current_state_ids(room_id)
+        state_ids = await self.store.get_current_state_ids(room_id)
         acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
 
         if not acl_event_id:
             return
 
-        acl_event = yield self.store.get_event(acl_event_id)
+        acl_event = await self.store.get_event(acl_event_id)
         if server_matches_acl_event(server_name, acl_event):
             return
 
         raise AuthError(code=403, msg="Server is banned from room")
 
 
-def server_matches_acl_event(server_name, acl_event):
+def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
     """Check if the given server is allowed by the ACL event
 
     Args:
-        server_name (str): name of server, without any port part
-        acl_event (EventBase): m.room.server_acl event
+        server_name: name of server, without any port part
+        acl_event: m.room.server_acl event
 
     Returns:
-        bool: True if this server is allowed by the ACLs
+        True if this server is allowed by the ACLs
     """
     logger.debug("Checking %s against acl %s", server_name, acl_event.content)
 
@@ -715,7 +715,7 @@ def server_matches_acl_event(server_name, acl_event):
     # server name is a literal IP
     allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
     if not isinstance(allow_ip_literals, bool):
-        logger.warn("Ignorning non-bool allow_ip_literals flag")
+        logger.warning("Ignorning non-bool allow_ip_literals flag")
         allow_ip_literals = True
     if not allow_ip_literals:
         # check for ipv6 literals. These start with '['.
@@ -729,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
     # next,  check the deny list
     deny = acl_event.content.get("deny", [])
     if not isinstance(deny, (list, tuple)):
-        logger.warn("Ignorning non-list deny ACL %s", deny)
+        logger.warning("Ignorning non-list deny ACL %s", deny)
         deny = []
     for e in deny:
         if _acl_entry_matches(server_name, e):
@@ -739,7 +739,7 @@ def server_matches_acl_event(server_name, acl_event):
     # then the allow list.
     allow = acl_event.content.get("allow", [])
     if not isinstance(allow, (list, tuple)):
-        logger.warn("Ignorning non-list allow ACL %s", allow)
+        logger.warning("Ignorning non-list allow ACL %s", allow)
         allow = []
     for e in allow:
         if _acl_entry_matches(server_name, e):
@@ -751,9 +751,9 @@ def server_matches_acl_event(server_name, acl_event):
     return False
 
 
-def _acl_entry_matches(server_name, acl_entry):
+def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
     if not isinstance(acl_entry, six.string_types):
-        logger.warn(
+        logger.warning(
             "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
         )
         return False
@@ -770,13 +770,13 @@ class FederationHandlerRegistry(object):
         self.edu_handlers = {}
         self.query_handlers = {}
 
-    def register_edu_handler(self, edu_type, handler):
+    def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
         """Sets the handler callable that will be used to handle an incoming
         federation EDU of the given type.
 
         Args:
-            edu_type (str): The type of the incoming EDU to register handler for
-            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+            edu_type: The type of the incoming EDU to register handler for
+            handler: A callable invoked on incoming EDU
                 of the given type. The arguments are the origin server name and
                 the EDU contents.
         """
@@ -787,14 +787,16 @@ class FederationHandlerRegistry(object):
 
         self.edu_handlers[edu_type] = handler
 
-    def register_query_handler(self, query_type, handler):
+    def register_query_handler(
+        self, query_type: str, handler: Callable[[dict], defer.Deferred]
+    ):
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
 
         Args:
-            query_type (str): Category name of the query, which should match
+            query_type: Category name of the query, which should match
                 the string used by make_query.
-            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+            handler: Invoked to handle
                 incoming queries of this type. The return will be yielded
                 on and the result used as the response to the query request.
         """
@@ -805,24 +807,24 @@ class FederationHandlerRegistry(object):
 
         self.query_handlers[query_type] = handler
 
-    @defer.inlineCallbacks
-    def on_edu(self, edu_type, origin, content):
+    async def on_edu(self, edu_type: str, origin: str, content: dict):
         handler = self.edu_handlers.get(edu_type)
         if not handler:
-            logger.warn("No handler registered for EDU type %s", edu_type)
+            logger.warning("No handler registered for EDU type %s", edu_type)
+            return
 
         with start_active_span_from_edu(content, "handle_edu"):
             try:
-                yield handler(origin, content)
+                await handler(origin, content)
             except SynapseError as e:
                 logger.info("Failed to handle edu %r: %r", edu_type, e)
             except Exception:
                 logger.exception("Failed to handle edu %r", edu_type)
 
-    def on_query(self, query_type, args):
+    def on_query(self, query_type: str, args: dict) -> defer.Deferred:
         handler = self.query_handlers.get(query_type)
         if not handler:
-            logger.warn("No handler registered for query type %s", query_type)
+            logger.warning("No handler registered for query type %s", query_type)
             raise NotFoundError("No handler for Query type '%s'" % (query_type,))
 
         return handler(args)
@@ -846,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         super(ReplicationFederationHandlerRegistry, self).__init__()
 
-    def on_edu(self, edu_type, origin, content):
+    async def on_edu(self, edu_type: str, origin: str, content: dict):
         """Overrides FederationHandlerRegistry
         """
         if not self.config.use_presence and edu_type == "m.presence":
@@ -854,17 +856,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         handler = self.edu_handlers.get(edu_type)
         if handler:
-            return super(ReplicationFederationHandlerRegistry, self).on_edu(
+            return await super(ReplicationFederationHandlerRegistry, self).on_edu(
                 edu_type, origin, content
             )
 
-        return self._send_edu(edu_type=edu_type, origin=origin, content=content)
+        return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
 
-    def on_query(self, query_type, args):
+    async def on_query(self, query_type: str, args: dict):
         """Overrides FederationHandlerRegistry
         """
         handler = self.query_handlers.get(query_type)
         if handler:
-            return handler(args)
+            return await handler(args)
 
-        return self._get_query_client(query_type=query_type, args=args)
+        return await self._get_query_client(query_type=query_type, args=args)