summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/7975.misc1
-rwxr-xr-xcontrib/cmdclient/console.py16
-rw-r--r--synapse/crypto/keyring.py60
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/sender/__init__.py19
-rw-r--r--synapse/federation/transport/client.py96
-rw-r--r--synapse/handlers/groups_local.py35
-rw-r--r--synapse/http/matrixfederationclient.py72
-rw-r--r--tests/crypto/test_keyring.py11
-rw-r--r--tests/federation/test_complexity.py21
-rw-r--r--tests/federation/test_federation_sender.py10
-rw-r--r--tests/handlers/test_directory.py5
-rw-r--r--tests/handlers/test_profile.py3
-rw-r--r--tests/http/test_fedclient.py50
-rw-r--r--tests/replication/test_federation_sender_shard.py13
-rw-r--r--tests/rest/admin/test_admin.py4
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/test_federation.py2
18 files changed, 209 insertions, 221 deletions
diff --git a/changelog.d/7975.misc b/changelog.d/7975.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7975.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index 77422f5e5d..dfc1d294dc 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd):
 
     @defer.inlineCallbacks
     def _do_event_stream(self, timeout):
-        res = yield self.http_client.get_json(
-            self._url() + "/events",
-            {
-                "access_token": self._tok(),
-                "timeout": str(timeout),
-                "from": self.event_stream_token,
-            },
+        res = yield defer.ensureDeferred(
+            self.http_client.get_json(
+                self._url() + "/events",
+                {
+                    "access_token": self._tok(),
+                    "timeout": str(timeout),
+                    "from": self.event_stream_token,
+                },
+            )
         )
         print(json.dumps(res, indent=4))
 
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index dbfc3e8972..443cde0b6d 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         )
 
         try:
-            query_response = yield self.client.post_json(
-                destination=perspective_name,
-                path="/_matrix/key/v2/query",
-                data={
-                    "server_keys": {
-                        server_name: {
-                            key_id: {"minimum_valid_until_ts": min_valid_ts}
-                            for key_id, min_valid_ts in server_keys.items()
+            query_response = yield defer.ensureDeferred(
+                self.client.post_json(
+                    destination=perspective_name,
+                    path="/_matrix/key/v2/query",
+                    data={
+                        "server_keys": {
+                            server_name: {
+                                key_id: {"minimum_valid_until_ts": min_valid_ts}
+                                for key_id, min_valid_ts in server_keys.items()
+                            }
+                            for server_name, server_keys in keys_to_fetch.items()
                         }
-                        for server_name, server_keys in keys_to_fetch.items()
-                    }
-                },
+                    },
+                )
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             # these both have str() representations which we can't really improve upon
@@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
             time_now_ms = self.clock.time_msec()
             try:
-                response = yield self.client.get_json(
-                    destination=server_name,
-                    path="/_matrix/key/v2/server/"
-                    + urllib.parse.quote(requested_key_id),
-                    ignore_backoff=True,
-                    # we only give the remote server 10s to respond. It should be an
-                    # easy request to handle, so if it doesn't reply within 10s, it's
-                    # probably not going to.
-                    #
-                    # Furthermore, when we are acting as a notary server, we cannot
-                    # wait all day for all of the origin servers, as the requesting
-                    # server will otherwise time out before we can respond.
-                    #
-                    # (Note that get_json may make 4 attempts, so this can still take
-                    # almost 45 seconds to fetch the headers, plus up to another 60s to
-                    # read the response).
-                    timeout=10000,
+                response = yield defer.ensureDeferred(
+                    self.client.get_json(
+                        destination=server_name,
+                        path="/_matrix/key/v2/server/"
+                        + urllib.parse.quote(requested_key_id),
+                        ignore_backoff=True,
+                        # we only give the remote server 10s to respond. It should be an
+                        # easy request to handle, so if it doesn't reply within 10s, it's
+                        # probably not going to.
+                        #
+                        # Furthermore, when we are acting as a notary server, we cannot
+                        # wait all day for all of the origin servers, as the requesting
+                        # server will otherwise time out before we can respond.
+                        #
+                        # (Note that get_json may make 4 attempts, so this can still take
+                        # almost 45 seconds to fetch the headers, plus up to another 60s to
+                        # read the response).
+                        timeout=10000,
+                    )
                 )
             except (NotRetryingDestination, RequestSendFailed) as e:
                 # these both have str() representations which we can't really improve
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 994e6c8d5a..38ac7ec699 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -135,7 +135,7 @@ class FederationClient(FederationBase):
                 and try the request anyway.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            a Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels(query_type).inc()
@@ -157,7 +157,7 @@ class FederationClient(FederationBase):
             content (dict): The query content.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            an Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels("client_device_keys").inc()
@@ -180,7 +180,7 @@ class FederationClient(FederationBase):
             content (dict): The query content.
 
         Returns:
-            a Deferred which will eventually yield a JSON object from the
+            an Awaitable which will eventually yield a JSON object from the
             response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
@@ -900,7 +900,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Deferred[Dict[str, Any]]: The response from the remote server, or None if
+            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
             `remote_server` is the same as the local server_name
 
         Raises:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index ba4ddd2370..8f549ae6ee 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -288,8 +288,7 @@ class FederationSender(object):
         for destination in destinations:
             self._get_per_destination_queue(destination).send_pdu(pdu, order)
 
-    @defer.inlineCallbacks
-    def send_read_receipt(self, receipt: ReadReceipt):
+    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
 
         Args:
@@ -330,9 +329,7 @@ class FederationSender(object):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains = yield defer.ensureDeferred(
-            self.state.get_current_hosts_in_room(room_id)
-        )
+        domains = await self.state.get_current_hosts_in_room(room_id)
         domains = [
             d
             for d in domains
@@ -387,8 +384,7 @@ class FederationSender(object):
             queue.flush_read_receipts_for_room(room_id)
 
     @preserve_fn  # the caller should not yield on this
-    @defer.inlineCallbacks
-    def send_presence(self, states: List[UserPresenceState]):
+    async def send_presence(self, states: List[UserPresenceState]):
         """Send the new presence states to the appropriate destinations.
 
         This actually queues up the presence states ready for sending and
@@ -423,7 +419,7 @@ class FederationSender(object):
                 if not states_map:
                     break
 
-                yield self._process_presence_inner(list(states_map.values()))
+                await self._process_presence_inner(list(states_map.values()))
         except Exception:
             logger.exception("Error sending presence states to servers")
         finally:
@@ -450,14 +446,11 @@ class FederationSender(object):
             self._get_per_destination_queue(destination).send_presence(states)
 
     @measure_func("txnqueue._process_presence")
-    @defer.inlineCallbacks
-    def _process_presence_inner(self, states: List[UserPresenceState]):
+    async def _process_presence_inner(self, states: List[UserPresenceState]):
         """Given a list of states populate self.pending_presence_by_dest and
         poke to send a new transaction to each destination
         """
-        hosts_and_states = yield defer.ensureDeferred(
-            get_interested_remotes(self.store, states, self.state)
-        )
+        hosts_and_states = await get_interested_remotes(self.store, states, self.state)
 
         for destinations, states in hosts_and_states:
             for destination in destinations:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index cfdf23d366..9ea821dbb2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -18,8 +18,6 @@ import logging
 import urllib
 from typing import Any, Dict, Optional
 
-from twisted.internet import defer
-
 from synapse.api.constants import Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.api.urls import (
@@ -51,7 +49,7 @@ class TransportLayerClient(object):
             event_id (str): The event we want the context at.
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
 
@@ -75,7 +73,7 @@ class TransportLayerClient(object):
                 giving up. None indicates no timeout.
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
 
@@ -96,7 +94,7 @@ class TransportLayerClient(object):
             limit (int)
 
         Returns:
-            Deferred: Results in a dict received from the remote homeserver.
+            Awaitable: Results in a dict received from the remote homeserver.
         """
         logger.debug(
             "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@@ -118,16 +116,15 @@ class TransportLayerClient(object):
             destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
-    @defer.inlineCallbacks
     @log_function
-    def send_transaction(self, transaction, json_data_callback=None):
+    async def send_transaction(self, transaction, json_data_callback=None):
         """ Sends the given Transaction to its destination
 
         Args:
             transaction (Transaction)
 
         Returns:
-            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body.
 
             Fails with ``HTTPRequestException`` if we get an HTTP response
@@ -154,7 +151,7 @@ class TransportLayerClient(object):
 
         path = _create_v1_path("/send/%s", transaction.transaction_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             transaction.destination,
             path=path,
             data=json_data,
@@ -166,14 +163,13 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def make_query(
+    async def make_query(
         self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
     ):
         path = _create_v1_path("/query/%s", query_type)
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination,
             path=path,
             args=args,
@@ -184,9 +180,10 @@ class TransportLayerClient(object):
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def make_membership_event(self, destination, room_id, user_id, membership, params):
+    async def make_membership_event(
+        self, destination, room_id, user_id, membership, params
+    ):
         """Asks a remote server to build and sign us a membership event
 
         Note that this does not append any events to any graphs.
@@ -200,7 +197,7 @@ class TransportLayerClient(object):
                 request.
 
         Returns:
-            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body (ie, the new event).
 
             Fails with ``HTTPRequestException`` if we get an HTTP response
@@ -231,7 +228,7 @@ class TransportLayerClient(object):
             ignore_backoff = True
             retry_on_dns_fail = True
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination,
             path=path,
             args=params,
@@ -242,34 +239,31 @@ class TransportLayerClient(object):
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def send_join_v1(self, destination, room_id, event_id, content):
+    async def send_join_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_join_v2(self, destination, room_id, event_id, content):
+    async def send_join_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_leave_v1(self, destination, room_id, event_id, content):
+    async def send_leave_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -282,12 +276,11 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_leave_v2(self, destination, room_id, event_id, content):
+    async def send_leave_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -300,31 +293,28 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_invite_v1(self, destination, room_id, event_id, content):
+    async def send_invite_v1(self, destination, room_id, event_id, content):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def send_invite_v2(self, destination, room_id, event_id, content):
+    async def send_invite_v2(self, destination, room_id, event_id, content):
         path = _create_v2_path("/invite/%s/%s", room_id, event_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -355,7 +345,7 @@ class TransportLayerClient(object):
             data["filter"] = search_filter
 
             try:
-                response = yield self.client.post_json(
+                response = await self.client.post_json(
                     destination=remote_server, path=path, data=data, ignore_backoff=True
                 )
             except HttpResponseException as e:
@@ -381,7 +371,7 @@ class TransportLayerClient(object):
                 args["since"] = [since_token]
 
             try:
-                response = yield self.client.get_json(
+                response = await self.client.get_json(
                     destination=remote_server, path=path, args=args, ignore_backoff=True
                 )
             except HttpResponseException as e:
@@ -396,29 +386,26 @@ class TransportLayerClient(object):
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def exchange_third_party_invite(self, destination, room_id, event_dict):
+    async def exchange_third_party_invite(self, destination, room_id, event_dict):
         path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
 
-        response = yield self.client.put_json(
+        response = await self.client.put_json(
             destination=destination, path=path, data=event_dict
         )
 
         return response
 
-    @defer.inlineCallbacks
     @log_function
-    def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(self, destination, room_id, event_id):
         path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
 
-        content = yield self.client.get_json(destination=destination, path=path)
+        content = await self.client.get_json(destination=destination, path=path)
 
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content, timeout):
+    async def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -453,14 +440,13 @@ class TransportLayerClient(object):
         """
         path = _create_v1_path("/user/keys/query")
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def query_user_devices(self, destination, user_id, timeout):
+    async def query_user_devices(self, destination, user_id, timeout):
         """Query the devices for a user id hosted on a remote server.
 
         Response:
@@ -493,14 +479,13 @@ class TransportLayerClient(object):
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
-        content = yield self.client.get_json(
+        content = await self.client.get_json(
             destination=destination, path=path, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content, timeout):
+    async def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -532,14 +517,13 @@ class TransportLayerClient(object):
 
         path = _create_v1_path("/user/keys/claim")
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
         return content
 
-    @defer.inlineCallbacks
     @log_function
-    def get_missing_events(
+    async def get_missing_events(
         self,
         destination,
         room_id,
@@ -551,7 +535,7 @@ class TransportLayerClient(object):
     ):
         path = _create_v1_path("/get_missing_events/%s", room_id)
 
-        content = yield self.client.post_json(
+        content = await self.client.post_json(
             destination=destination,
             path=path,
             data={
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ecdb12a7bf..0e2656ccb3 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -23,39 +23,32 @@ logger = logging.getLogger(__name__)
 
 
 def _create_rerouter(func_name):
-    """Returns a function that looks at the group id and calls the function
+    """Returns an async function that looks at the group id and calls the function
     on federation or the local group server if the group is local
     """
 
-    def f(self, group_id, *args, **kwargs):
+    async def f(self, group_id, *args, **kwargs):
         if self.is_mine_id(group_id):
-            return getattr(self.groups_server_handler, func_name)(
+            return await getattr(self.groups_server_handler, func_name)(
                 group_id, *args, **kwargs
             )
         else:
             destination = get_domain_from_id(group_id)
-            d = getattr(self.transport_client, func_name)(
-                destination, group_id, *args, **kwargs
-            )
 
-            # Capture errors returned by the remote homeserver and
-            # re-throw specific errors as SynapseErrors. This is so
-            # when the remote end responds with things like 403 Not
-            # In Group, we can communicate that to the client instead
-            # of a 500.
-            def http_response_errback(failure):
-                failure.trap(HttpResponseException)
-                e = failure.value
+            try:
+                return await getattr(self.transport_client, func_name)(
+                    destination, group_id, *args, **kwargs
+                )
+            except HttpResponseException as e:
+                # Capture errors returned by the remote homeserver and
+                # re-throw specific errors as SynapseErrors. This is so
+                # when the remote end responds with things like 403 Not
+                # In Group, we can communicate that to the client instead
+                # of a 500.
                 raise e.to_synapse_error()
-
-            def request_failed_errback(failure):
-                failure.trap(RequestSendFailed)
+            except RequestSendFailed:
                 raise SynapseError(502, "Failed to contact group server")
 
-            d.addErrback(http_response_errback)
-            d.addErrback(request_failed_errback)
-            return d
-
     return f
 
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ea026ed9f4..2a6373937a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -121,8 +121,7 @@ class MatrixFederationRequest(object):
         return self.json
 
 
-@defer.inlineCallbacks
-def _handle_json_response(reactor, timeout_sec, request, response):
+async def _handle_json_response(reactor, timeout_sec, request, response):
     """
     Reads the JSON body of a response, with a timeout
 
@@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
         d = treq.json_content(response)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
-        body = yield make_deferred_yieldable(d)
+        body = await make_deferred_yieldable(d)
     except TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response", request.txn_id, request.destination,
@@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object):
 
         self._cooperator = Cooperator(scheduler=schedule)
 
-    @defer.inlineCallbacks
-    def _send_request_with_optional_trailing_slash(
+    async def _send_request_with_optional_trailing_slash(
         self, request, try_trailing_slash_on_400=False, **send_request_args
     ):
         """Wrapper for _send_request which can optionally retry the request
@@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object):
                 (except 429).
 
         Returns:
-            Deferred[Dict]: Parsed JSON response body.
+            Dict: Parsed JSON response body.
         """
         try:
-            response = yield self._send_request(request, **send_request_args)
+            response = await self._send_request(request, **send_request_args)
         except HttpResponseException as e:
             # Received an HTTP error > 300. Check if it meets the requirements
             # to retry with a trailing slash
@@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object):
             logger.info("Retrying request with trailing slash")
             request.path += "/"
 
-            response = yield self._send_request(request, **send_request_args)
+            response = await self._send_request(request, **send_request_args)
 
         return response
 
-    @defer.inlineCallbacks
-    def _send_request(
+    async def _send_request(
         self,
         request,
         retry_on_dns_fail=True,
@@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object):
             backoff_on_404 (bool): Back off if we get a 404
 
         Returns:
-            Deferred[twisted.web.client.Response]: resolves with the HTTP
+            twisted.web.client.Response: resolves with the HTTP
             response object on success.
 
         Raises:
@@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object):
         ):
             raise FederationDeniedError(request.destination)
 
-        limiter = yield synapse.util.retryutils.get_retry_limiter(
+        limiter = await synapse.util.retryutils.get_retry_limiter(
             request.destination,
             self.clock,
             self._store,
@@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object):
                                 reactor=self.reactor,
                             )
 
-                            response = yield request_deferred
+                            response = await request_deferred
                     except TimeoutError as e:
                         raise RequestSendFailed(e, can_retry=True) from e
                     except DNSLookupError as e:
@@ -474,7 +471,7 @@ class MatrixFederationHttpClient(object):
                         )
 
                         try:
-                            body = yield make_deferred_yieldable(d)
+                            body = await make_deferred_yieldable(d)
                         except Exception as e:
                             # Eh, we're already going to raise an exception so lets
                             # ignore if this fails.
@@ -528,7 +525,7 @@ class MatrixFederationHttpClient(object):
                             delay,
                         )
 
-                        yield self.clock.sleep(delay)
+                        await self.clock.sleep(delay)
                         retries_left -= 1
                     else:
                         raise
@@ -591,8 +588,7 @@ class MatrixFederationHttpClient(object):
             )
         return auth_headers
 
-    @defer.inlineCallbacks
-    def put_json(
+    async def put_json(
         self,
         destination,
         path,
@@ -636,7 +632,7 @@ class MatrixFederationHttpClient(object):
                 enabled.
 
         Returns:
-            Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+            dict|list: Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
 
         Raises:
@@ -658,7 +654,7 @@ class MatrixFederationHttpClient(object):
             json=data,
         )
 
-        response = yield self._send_request_with_optional_trailing_slash(
+        response = await self._send_request_with_optional_trailing_slash(
             request,
             try_trailing_slash_on_400,
             backoff_on_404=backoff_on_404,
@@ -667,14 +663,13 @@ class MatrixFederationHttpClient(object):
             timeout=timeout,
         )
 
-        body = yield _handle_json_response(
+        body = await _handle_json_response(
             self.reactor, self.default_timeout, request, response
         )
 
         return body
 
-    @defer.inlineCallbacks
-    def post_json(
+    async def post_json(
         self,
         destination,
         path,
@@ -707,7 +702,7 @@ class MatrixFederationHttpClient(object):
 
             args (dict): query params
         Returns:
-            Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+            dict|list: Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
 
         Raises:
@@ -725,7 +720,7 @@ class MatrixFederationHttpClient(object):
             method="POST", destination=destination, path=path, query=args, json=data
         )
 
-        response = yield self._send_request(
+        response = await self._send_request(
             request,
             long_retries=long_retries,
             timeout=timeout,
@@ -737,13 +732,12 @@ class MatrixFederationHttpClient(object):
         else:
             _sec_timeout = self.default_timeout
 
-        body = yield _handle_json_response(
+        body = await _handle_json_response(
             self.reactor, _sec_timeout, request, response
         )
         return body
 
-    @defer.inlineCallbacks
-    def get_json(
+    async def get_json(
         self,
         destination,
         path,
@@ -775,7 +769,7 @@ class MatrixFederationHttpClient(object):
                 response we should try appending a trailing slash to the end of
                 the request. Workaround for #3622 in Synapse <= v0.99.3.
         Returns:
-            Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+            dict|list: Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
 
         Raises:
@@ -792,7 +786,7 @@ class MatrixFederationHttpClient(object):
             method="GET", destination=destination, path=path, query=args
         )
 
-        response = yield self._send_request_with_optional_trailing_slash(
+        response = await self._send_request_with_optional_trailing_slash(
             request,
             try_trailing_slash_on_400,
             backoff_on_404=False,
@@ -801,14 +795,13 @@ class MatrixFederationHttpClient(object):
             timeout=timeout,
         )
 
-        body = yield _handle_json_response(
+        body = await _handle_json_response(
             self.reactor, self.default_timeout, request, response
         )
 
         return body
 
-    @defer.inlineCallbacks
-    def delete_json(
+    async def delete_json(
         self,
         destination,
         path,
@@ -836,7 +829,7 @@ class MatrixFederationHttpClient(object):
 
             args (dict): query params
         Returns:
-            Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+            dict|list: Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
 
         Raises:
@@ -853,20 +846,19 @@ class MatrixFederationHttpClient(object):
             method="DELETE", destination=destination, path=path, query=args
         )
 
-        response = yield self._send_request(
+        response = await self._send_request(
             request,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
         )
 
-        body = yield _handle_json_response(
+        body = await _handle_json_response(
             self.reactor, self.default_timeout, request, response
         )
         return body
 
-    @defer.inlineCallbacks
-    def get_file(
+    async def get_file(
         self,
         destination,
         path,
@@ -886,7 +878,7 @@ class MatrixFederationHttpClient(object):
                 and try the request anyway.
 
         Returns:
-            Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
+            tuple[int, dict]: Resolves with an (int,dict) tuple of
             the file length and a dict of the response headers.
 
         Raises:
@@ -903,7 +895,7 @@ class MatrixFederationHttpClient(object):
             method="GET", destination=destination, path=path, query=args
         )
 
-        response = yield self._send_request(
+        response = await self._send_request(
             request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
         )
 
@@ -912,7 +904,7 @@ class MatrixFederationHttpClient(object):
         try:
             d = _readBodyToFile(response, output_stream, max_size)
             d.addTimeout(self.default_timeout, self.reactor)
-            length = yield make_deferred_yieldable(d)
+            length = await make_deferred_yieldable(d)
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..e0ad8e8a77 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         }
         persp_deferred = defer.Deferred()
 
-        @defer.inlineCallbacks
-        def get_perspectives(**kwargs):
+        async def get_perspectives(**kwargs):
             self.assertEquals(current_context().request, "11")
             with PreserveLoggingContext():
-                yield persp_deferred
+                await persp_deferred
             return persp_resp
 
         self.http_client.post_json.side_effect = get_perspectives
@@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         }
         signedjson.sign.sign_json(response, SERVER_NAME, testkey)
 
-        def get_json(destination, path, **kwargs):
+        async def get_json(destination, path, **kwargs):
             self.assertEqual(destination, SERVER_NAME)
             self.assertEqual(path, "/_matrix/key/v2/server/key1")
             return response
@@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect a perspectives-server key query
         """
 
-        def post_json(destination, path, data, **kwargs):
+        async def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
             self.assertEqual(path, "/_matrix/key/v2/query")
 
@@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # remove the perspectives server's signature
         response = build_response()
         del response["signatures"][self.mock_perspective_server.server_name]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
 
         # remove the origin server's signature
         response = build_response()
         del response["signatures"][SERVER_NAME]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 5cd0510f0d..b8ca118716 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
@@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d1bd18da39..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
         self.pump()
         mock_send_transaction.assert_not_called()
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse
 import synapse.api.errors
 from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
 from synapse.types import RoomAlias, create_requester
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f1347cd25..d70e1fc608 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_other_name(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
         @defer.inlineCallbacks
         def do_request():
             with LoggingContext("one") as context:
-                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+                fetch_d = defer.ensureDeferred(
+                    self.cl.get_json("testserv:8008", "foo/bar")
+                )
 
                 # Nothing happened yet
                 self.assertNoResult(fetch_d)
@@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS lookup returns an error, it will bubble up.
         """
-        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        )
         self.pump()
 
         f = self.failureResultOf(d)
@@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
     def test_client_connection_refused(self):
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a blacklisted IPv4 address
         # ------------------------------------------------------
         # Make the request
-        d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
 
         # Nothing happened yet
         self.assertNoResult(d)
@@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a POST request to a blacklisted IPv6 address
         # -------------------------------------------------------
         # Make the request
-        d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        )
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a non-blacklisted IPv4 address
         # ----------------------------------------------------------
         # Make the request
-        d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
         request = MatrixFederationRequest(
             method="GET", destination="testserv:8008", path="foo/bar"
         )
-        d = self.cl._send_request(request, timeout=10000)
+        d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
 
         self.pump()
 
@@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
         requiring a trailing slash. We need to retry the request with a
         trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
 
         See test_client_requires_trailing_slashes() for context.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
         self.failureResultOf(d)
 
     def test_client_sends_body(self):
-        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+        defer.ensureDeferred(
+            self.cl.post_json(
+                "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+            )
+        )
 
         self.pump()
 
@@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_closes_connection(self):
         """Check that the client closes unused HTTP connections"""
-        d = self.cl.get_json("testserv:8008", "foo/bar")
+        d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
         self.pump()
 
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..83f9aa291c 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,8 +16,6 @@ import logging
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
 from synapse.rest.admin import register_servlets_for_client_rest_resource
@@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
@@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        async def get_file(destination, path, output_stream, args=None, max_size=None):
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             d = Deferred()
             d.addCallback(write_to)
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            return await make_deferred_yieldable(d)
 
         client = Mock()
         client.get_file = get_file
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(destination, path, ignore_backoff=False, **kwargs):
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        def post_json(destination, path, data):
+        async def post_json(destination, path, data):
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path, "/_matrix/key/v2/query",
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..c2f12c2741 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         prev_events that said event references.
         """
 
-        def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(destination, path, data, headers=None, timeout=0):
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}