summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/7871.misc1
-rw-r--r--synapse/handlers/device.py241
-rw-r--r--synapse/util/distributor.py28
-rw-r--r--tests/handlers/test_device.py13
-rw-r--r--tests/handlers/test_e2e_keys.py10
-rw-r--r--tests/test_federation.py35
6 files changed, 162 insertions, 166 deletions
diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc
new file mode 100644
index 0000000000..4d398a9f3a
--- /dev/null
+++ b/changelog.d/7871.misc
@@ -0,0 +1 @@
+Convert device handler to async/await.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 31346b56c3..f947aa1627 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,9 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, Optional
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Optional
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
 
     @trace
-    @defer.inlineCallbacks
-    def get_devices_by_user(self, user_id):
+    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
         """
         Retrieve the given user's devices
 
         Args:
-            user_id (str):
+            user_id: The user ID to query for devices.
         Returns:
-            defer.Deferred: list[dict[str, X]]: info on each device
+            info on each device
         """
 
         set_tag("user_id", user_id)
-        device_map = yield self.store.get_devices_by_user(user_id)
+        device_map = await self.store.get_devices_by_user(user_id)
 
-        ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
+        ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
 
         devices = list(device_map.values())
         for device in devices:
@@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
 
     @trace
-    @defer.inlineCallbacks
-    def get_device(self, user_id, device_id):
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """ Retrieve the given device
 
         Args:
-            user_id (str):
-            device_id (str):
+            user_id: The user to get the device from
+            device_id: The device to fetch.
 
         Returns:
-            defer.Deferred: dict[str, X]: info on the device
+            info on the device
         Raises:
             errors.NotFoundError: if the device was not found
         """
         try:
-            device = yield self.store.get_device(user_id, device_id)
+            device = await self.store.get_device(user_id, device_id)
         except errors.StoreError:
             raise errors.NotFoundError
-        ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
+        ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
         set_tag("device", device)
@@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
 
         return device
 
-    @measure_func("device.get_user_ids_changed")
     @trace
-    @defer.inlineCallbacks
-    def get_user_ids_changed(self, user_id, from_token):
+    @measure_func("device.get_user_ids_changed")
+    async def get_user_ids_changed(self, user_id, from_token):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
 
@@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = yield self.store.get_room_events_max_id()
+        now_room_key = await self.store.get_room_events_max_id()
 
-        room_ids = yield self.store.get_rooms_for_user(user_id)
+        room_ids = await self.store.get_rooms_for_user(user_id)
 
         # First we check if any devices have changed for users that we share
         # rooms with.
-        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+        users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id
         )
 
@@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
         # Always tell the user about their own devices
         tracked_users.add(user_id)
 
-        changed = yield self.store.get_users_whose_devices_changed(
+        changed = await self.store.get_users_whose_devices_changed(
             from_token.device_list_key, tracked_users
         )
 
         # Then work out if any users have since joined
         rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
 
-        member_events = yield self.store.get_membership_changes_for_user(
+        member_events = await self.store.get_membership_changes_for_user(
             user_id, from_token.room_key, now_room_key
         )
         rooms_changed.update(event.room_id for event in member_events)
@@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
         possibly_changed = set(changed)
         possibly_left = set()
         for room_id in rooms_changed:
-            current_state_ids = yield self.store.get_current_state_ids(room_id)
+            current_state_ids = await self.store.get_current_state_ids(room_id)
 
             # The user may have left the room
             # TODO: Check if they actually did or if we were just invited.
@@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
 
             # Fetch the current state at the time.
             try:
-                event_ids = yield self.store.get_forward_extremeties_for_room(
+                event_ids = await self.store.get_forward_extremeties_for_room(
                     room_id, stream_ordering=stream_ordering
                 )
             except errors.StoreError:
@@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
                 continue
 
             # mapping from event_id -> state_dict
-            prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
+            prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
 
             # Check if we've joined the room? If so we just blindly add all the users to
             # the "possibly changed" users.
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
 
         return result
 
-    @defer.inlineCallbacks
-    def on_federation_query_user_devices(self, user_id):
-        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
-        master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
-        self_signing_key = yield self.store.get_e2e_cross_signing_key(
+    async def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
+        master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
+        self_signing_key = await self.store.get_e2e_cross_signing_key(
             user_id, "self_signing"
         )
 
@@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-    @defer.inlineCallbacks
-    def check_device_registered(
+    async def check_device_registered(
         self, user_id, device_id, initial_device_display_name=None
     ):
         """
@@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
             str: device id (generated if none was supplied)
         """
         if device_id is not None:
-            new_device = yield self.store.store_device(
+            new_device = await self.store.store_device(
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, [device_id])
+                await self.notify_device_update(user_id, [device_id])
             return device_id
 
         # if the device id is not specified, we'll autogen one, but loop a few
@@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
         attempts = 0
         while attempts < 5:
             device_id = stringutils.random_string(10).upper()
-            new_device = yield self.store.store_device(
+            new_device = await self.store.store_device(
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, [device_id])
+                await self.notify_device_update(user_id, [device_id])
                 return device_id
             attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
 
     @trace
-    @defer.inlineCallbacks
-    def delete_device(self, user_id, device_id):
+    async def delete_device(self, user_id: str, device_id: str) -> None:
         """ Delete the given device
 
         Args:
-            user_id (str):
-            device_id (str):
-
-        Returns:
-            defer.Deferred:
+            user_id: The user to delete the device from.
+            device_id: The device to delete.
         """
 
         try:
-            yield self.store.delete_device(user_id, device_id)
+            await self.store.delete_device(user_id, device_id)
         except errors.StoreError as e:
             if e.code == 404:
                 # no match
@@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
             else:
                 raise
 
-        yield defer.ensureDeferred(
-            self._auth_handler.delete_access_tokens_for_user(
-                user_id, device_id=device_id
-            )
+        await self._auth_handler.delete_access_tokens_for_user(
+            user_id, device_id=device_id
         )
 
-        yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
+        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
 
-        yield self.notify_device_update(user_id, [device_id])
+        await self.notify_device_update(user_id, [device_id])
 
     @trace
-    @defer.inlineCallbacks
-    def delete_all_devices_for_user(self, user_id, except_device_id=None):
+    async def delete_all_devices_for_user(
+        self, user_id: str, except_device_id: Optional[str] = None
+    ) -> None:
         """Delete all of the user's devices
 
         Args:
-            user_id (str):
-            except_device_id (str|None): optional device id which should not
-                be deleted
-
-        Returns:
-            defer.Deferred:
+            user_id: The user to remove all devices from
+            except_device_id: optional device id which should not be deleted
         """
-        device_map = yield self.store.get_devices_by_user(user_id)
+        device_map = await self.store.get_devices_by_user(user_id)
         device_ids = list(device_map)
         if except_device_id is not None:
             device_ids = [d for d in device_ids if d != except_device_id]
-        yield self.delete_devices(user_id, device_ids)
+        await self.delete_devices(user_id, device_ids)
 
-    @defer.inlineCallbacks
-    def delete_devices(self, user_id, device_ids):
+    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
         """ Delete several devices
 
         Args:
-            user_id (str):
-            device_ids (List[str]): The list of device IDs to delete
-
-        Returns:
-            defer.Deferred:
+            user_id: The user to delete devices from.
+            device_ids: The list of device IDs to delete
         """
 
         try:
-            yield self.store.delete_devices(user_id, device_ids)
+            await self.store.delete_devices(user_id, device_ids)
         except errors.StoreError as e:
             if e.code == 404:
                 # no match
@@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
         # Delete access tokens and e2e keys for each device. Not optimised as it is not
         # considered as part of a critical path.
         for device_id in device_ids:
-            yield defer.ensureDeferred(
-                self._auth_handler.delete_access_tokens_for_user(
-                    user_id, device_id=device_id
-                )
+            await self._auth_handler.delete_access_tokens_for_user(
+                user_id, device_id=device_id
             )
-            yield self.store.delete_e2e_keys_by_device(
+            await self.store.delete_e2e_keys_by_device(
                 user_id=user_id, device_id=device_id
             )
 
-        yield self.notify_device_update(user_id, device_ids)
+        await self.notify_device_update(user_id, device_ids)
 
-    @defer.inlineCallbacks
-    def update_device(self, user_id, device_id, content):
+    async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
         """ Update the given device
 
         Args:
-            user_id (str):
-            device_id (str):
-            content (dict): body of update request
-
-        Returns:
-            defer.Deferred:
+            user_id: The user to update devices of.
+            device_id: The device to update.
+            content: body of update request
         """
 
         # Reject a new displayname which is too long.
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
             )
 
         try:
-            yield self.store.update_device(
+            await self.store.update_device(
                 user_id, device_id, new_display_name=new_display_name
             )
-            yield self.notify_device_update(user_id, [device_id])
+            await self.notify_device_update(user_id, [device_id])
         except errors.StoreError as e:
             if e.code == 404:
                 raise errors.NotFoundError()
@@ -443,12 +417,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
     @trace
     @measure_func("notify_device_update")
-    @defer.inlineCallbacks
-    def notify_device_update(self, user_id, device_ids):
+    async def notify_device_update(self, user_id, device_ids):
         """Notify that a user's device(s) has changed. Pokes the notifier, and
         remote servers if the user is local.
         """
-        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+        users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id
         )
 
@@ -459,7 +432,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         set_tag("target_hosts", hosts)
 
-        position = yield self.store.add_device_change_to_streams(
+        position = await self.store.add_device_change_to_streams(
             user_id, device_ids, list(hosts)
         )
 
@@ -468,11 +441,11 @@ class DeviceHandler(DeviceWorkerHandler):
                 "Notifying about update %r/%r, ID: %r", user_id, device_id, position
             )
 
-        room_ids = yield self.store.get_rooms_for_user(user_id)
+        room_ids = await self.store.get_rooms_for_user(user_id)
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        yield self.notifier.on_new_event(
+        self.notifier.on_new_event(
             "device_list_key", position, users=[user_id], rooms=room_ids
         )
 
@@ -484,29 +457,29 @@ class DeviceHandler(DeviceWorkerHandler):
                 self.federation_sender.send_device_messages(host)
                 log_kv({"message": "sent device update to host", "host": host})
 
-    @defer.inlineCallbacks
-    def notify_user_signature_update(self, from_user_id, user_ids):
+    async def notify_user_signature_update(
+        self, from_user_id: str, user_ids: List[str]
+    ) -> None:
         """Notify a user that they have made new signatures of other users.
 
         Args:
-            from_user_id (str): the user who made the signature
-            user_ids (list[str]): the users IDs that have new signatures
+            from_user_id: the user who made the signature
+            user_ids: the users IDs that have new signatures
         """
 
-        position = yield self.store.add_user_signature_change_to_streams(
+        position = await self.store.add_user_signature_change_to_streams(
             from_user_id, user_ids
         )
 
         self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
 
-    @defer.inlineCallbacks
-    def user_left_room(self, user, room_id):
+    async def user_left_room(self, user, room_id):
         user_id = user.to_string()
-        room_ids = yield self.store.get_rooms_for_user(user_id)
+        room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
             # We no longer share rooms with this user, so we'll no longer
             # receive device updates. Mark this in DB.
-            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+            await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
 
 
 def _update_device_from_client_ips(device, client_ips):
@@ -549,8 +522,7 @@ class DeviceListUpdater(object):
         )
 
     @trace
-    @defer.inlineCallbacks
-    def incoming_device_list_update(self, origin, edu_content):
+    async def incoming_device_list_update(self, origin, edu_content):
         """Called on incoming device list update from federation. Responsible
         for parsing the EDU and adding to pending updates list.
         """
@@ -583,7 +555,7 @@ class DeviceListUpdater(object):
             )
             return
 
-        room_ids = yield self.store.get_rooms_for_user(user_id)
+        room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
             # We don't share any rooms with this user. Ignore update, as we
             # probably won't get any further updates.
@@ -608,14 +580,13 @@ class DeviceListUpdater(object):
             (device_id, stream_id, prev_ids, edu_content)
         )
 
-        yield self._handle_device_updates(user_id)
+        await self._handle_device_updates(user_id)
 
     @measure_func("_incoming_device_list_update")
-    @defer.inlineCallbacks
-    def _handle_device_updates(self, user_id):
+    async def _handle_device_updates(self, user_id):
         "Actually handle pending updates."
 
-        with (yield self._remote_edu_linearizer.queue(user_id)):
+        with (await self._remote_edu_linearizer.queue(user_id)):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
                 # This can happen since we batch updates
@@ -632,7 +603,7 @@ class DeviceListUpdater(object):
 
             # Given a list of updates we check if we need to resync. This
             # happens if we've missed updates.
-            resync = yield self._need_to_do_resync(user_id, pending_updates)
+            resync = await self._need_to_do_resync(user_id, pending_updates)
 
             if logger.isEnabledFor(logging.INFO):
                 logger.info(
@@ -643,16 +614,16 @@ class DeviceListUpdater(object):
                 )
 
             if resync:
-                yield self.user_device_resync(user_id)
+                await self.user_device_resync(user_id)
             else:
                 # Simply update the single device, since we know that is the only
                 # change (because of the single prev_id matching the current cache)
                 for device_id, stream_id, prev_ids, content in pending_updates:
-                    yield self.store.update_remote_device_list_cache_entry(
+                    await self.store.update_remote_device_list_cache_entry(
                         user_id, device_id, content, stream_id
                     )
 
-                yield self.device_handler.notify_device_update(
+                await self.device_handler.notify_device_update(
                     user_id, [device_id for device_id, _, _, _ in pending_updates]
                 )
 
@@ -660,14 +631,13 @@ class DeviceListUpdater(object):
                     stream_id for _, stream_id, _, _ in pending_updates
                 )
 
-    @defer.inlineCallbacks
-    def _need_to_do_resync(self, user_id, updates):
+    async def _need_to_do_resync(self, user_id, updates):
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
         seen_updates = self._seen_updates.get(user_id, set())
 
-        extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
+        extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
 
         logger.debug("Current extremity for %r: %r", user_id, extremity)
 
@@ -692,8 +662,7 @@ class DeviceListUpdater(object):
         return False
 
     @trace
-    @defer.inlineCallbacks
-    def _maybe_retry_device_resync(self):
+    async def _maybe_retry_device_resync(self):
         """Retry to resync device lists that are out of sync, except if another retry is
         in progress.
         """
@@ -705,12 +674,12 @@ class DeviceListUpdater(object):
             # we don't send too many requests.
             self._resync_retry_in_progress = True
             # Get all of the users that need resyncing.
-            need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
+            need_resync = await self.store.get_user_ids_requiring_device_list_resync()
             # Iterate over the set of user IDs.
             for user_id in need_resync:
                 try:
                     # Try to resync the current user's devices list.
-                    result = yield self.user_device_resync(
+                    result = await self.user_device_resync(
                         user_id=user_id, mark_failed_as_stale=False,
                     )
 
@@ -734,16 +703,17 @@ class DeviceListUpdater(object):
             # Allow future calls to retry resyncinc out of sync device lists.
             self._resync_retry_in_progress = False
 
-    @defer.inlineCallbacks
-    def user_device_resync(self, user_id, mark_failed_as_stale=True):
+    async def user_device_resync(
+        self, user_id: str, mark_failed_as_stale: bool = True
+    ) -> Optional[dict]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
-            user_id (str): The user's id whose device_list will be updated.
-            mark_failed_as_stale (bool): Whether to mark the user's device list as stale
+            user_id: The user's id whose device_list will be updated.
+            mark_failed_as_stale: Whether to mark the user's device list as stale
                 if the attempt to resync failed.
         Returns:
-            Deferred[dict]: a dict with device info as under the "devices" in the result of this
+            A dict with device info as under the "devices" in the result of this
             request:
             https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
         """
@@ -752,12 +722,12 @@ class DeviceListUpdater(object):
         # Fetch all devices for the user.
         origin = get_domain_from_id(user_id)
         try:
-            result = yield self.federation.query_user_devices(origin, user_id)
+            result = await self.federation.query_user_devices(origin, user_id)
         except NotRetryingDestination:
             if mark_failed_as_stale:
                 # Mark the remote user's device list as stale so we know we need to retry
                 # it later.
-                yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+                await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
             return
         except (RequestSendFailed, HttpResponseException) as e:
@@ -768,7 +738,7 @@ class DeviceListUpdater(object):
             if mark_failed_as_stale:
                 # Mark the remote user's device list as stale so we know we need to retry
                 # it later.
-                yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+                await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
             # We abort on exceptions rather than accepting the update
             # as otherwise synapse will 'forget' that its device list
@@ -792,7 +762,7 @@ class DeviceListUpdater(object):
             if mark_failed_as_stale:
                 # Mark the remote user's device list as stale so we know we need to retry
                 # it later.
-                yield self.store.mark_remote_user_device_cache_as_stale(user_id)
+                await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
             return
         log_kv({"result": result})
@@ -833,25 +803,24 @@ class DeviceListUpdater(object):
                 stream_id,
             )
 
-        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+        await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
         device_ids = [device["device_id"] for device in devices]
 
         # Handle cross-signing keys.
-        cross_signing_device_ids = yield self.process_cross_signing_key_update(
+        cross_signing_device_ids = await self.process_cross_signing_key_update(
             user_id, master_key, self_signing_key,
         )
         device_ids = device_ids + cross_signing_device_ids
 
-        yield self.device_handler.notify_device_update(user_id, device_ids)
+        await self.device_handler.notify_device_update(user_id, device_ids)
 
         # We clobber the seen updates since we've re-synced from a given
         # point.
         self._seen_updates[user_id] = {stream_id}
 
-        defer.returnValue(result)
+        return result
 
-    @defer.inlineCallbacks
-    def process_cross_signing_key_update(
+    async def process_cross_signing_key_update(
         self,
         user_id: str,
         master_key: Optional[Dict[str, Any]],
@@ -872,14 +841,14 @@ class DeviceListUpdater(object):
         device_ids = []
 
         if master_key:
-            yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+            await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
             _, verify_key = get_verify_key_from_cross_signing_key(master_key)
             # verify_key is a VerifyKey from signedjson, which uses
             # .version to denote the portion of the key ID after the
             # algorithm and colon, which is the device ID
             device_ids.append(verify_key.version)
         if self_signing_key:
-            yield self.store.set_e2e_cross_signing_key(
+            await self.store.set_e2e_cross_signing_key(
                 user_id, "self_signing", self_signing_key
             )
             _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index da20523b70..22a857a306 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 import logging
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred, fail, succeed
+from twisted.python import failure
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -79,6 +81,28 @@ class Distributor(object):
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
+def maybeAwaitableDeferred(f, *args, **kw):
+    """
+    Invoke a function that may or may not return a Deferred or an Awaitable.
+
+    This is a modified version of twisted.internet.defer.maybeDeferred.
+    """
+    try:
+        result = f(*args, **kw)
+    except Exception:
+        return fail(failure.Failure(captureVars=Deferred.debug))
+
+    if isinstance(result, Deferred):
+        return result
+    # Handle the additional case of an awaitable being returned.
+    elif inspect.isawaitable(result):
+        return defer.ensureDeferred(result)
+    elif isinstance(result, failure.Failure):
+        return fail(result)
+    else:
+        return succeed(result)
+
+
 class Signal(object):
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
@@ -122,7 +146,7 @@ class Signal(object):
                     ),
                 )
 
-            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+            return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
 
         deferreds = [run_in_background(do, o) for o in self.observers]
 
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        res = self.handler.get_device(user1, "abc")
-        self.pump()
-        self.assertIsInstance(
-            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        self.get_failure(
+            self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
         )
 
         # we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
     def test_update_unknown_device(self):
         update = {"display_name": "new_display"}
-        res = self.handler.update_device("user_id", "unknown_device_id", update)
-        self.pump()
-        self.assertIsInstance(
-            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        self.get_failure(
+            self.handler.update_device("user_id", "unknown_device_id", update),
+            synapse.api.errors.NotFoundError,
         )
 
     def _record_users(self):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index cdd093ffa8..210ddcbb88 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         res = None
         try:
-            yield self.hs.get_device_handler().check_device_registered(
-                user_id=local_user,
-                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                initial_device_display_name="new display name",
+            yield defer.ensureDeferred(
+                self.hs.get_device_handler().check_device_registered(
+                    user_id=local_user,
+                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                    initial_device_display_name="new display name",
+                )
             )
         except errors.SynapseError as e:
             res = e.code
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 89dcc58b99..87a16d7d7a 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
         store = self.homeserver.get_datastore()
-        store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+        store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register mock device list retrieval on the federation client.
         federation_client = self.homeserver.get_federation_client()
         federation_client.query_user_devices = Mock(
-            return_value={
-                "user_id": remote_user_id,
-                "stream_id": 1,
-                "devices": [],
-                "master_key": {
+            return_value=succeed(
+                {
                     "user_id": remote_user_id,
-                    "usage": ["master"],
-                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                },
-                "self_signing_key": {
-                    "user_id": remote_user_id,
-                    "usage": ["self_signing"],
-                    "keys": {
-                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
+                    "stream_id": 1,
+                    "devices": [],
+                    "master_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                },
-            }
+                    "self_signing_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    },
+                }
+            )
         )
 
         # Resync the device list.