summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8069.misc1
-rw-r--r--synapse/storage/databases/main/devices.py333
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/storage/test_devices.py44
-rw-r--r--tests/storage/test_end_to_end_keys.py16
5 files changed, 220 insertions, 176 deletions
diff --git a/changelog.d/8069.misc b/changelog.d/8069.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8069.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a5f0bab05..2b33060480 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,9 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -33,14 +31,9 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
-from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import (
-    Cache,
-    cached,
-    cachedInlineCallbacks,
-    cachedList,
-)
+from synapse.util.caches.descriptors import Cache, cached, cachedList
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def get_device(self, user_id, device_id):
+    def get_device(self, user_id: str, device_id: str):
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
         Args:
-            user_id (str): The ID of the user which owns the device
-            device_id (str): The ID of the device to retrieve
+            user_id: The ID of the user which owns the device
+            device_id: The ID of the device to retrieve
         Returns:
             defer.Deferred for a dict containing the device information
         Raises:
@@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="get_device",
         )
 
-    @defer.inlineCallbacks
-    def get_devices_by_user(self, user_id):
+    async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
         """Retrieve all of a user's registered devices. Only returns devices
         that are not marked as hidden.
 
         Args:
-            user_id (str):
+            user_id:
         Returns:
-            defer.Deferred: resolves to a dict from device_id to a dict
-            containing "device_id", "user_id" and "display_name" for each
-            device.
+            A mapping from device_id to a dict containing "device_id", "user_id"
+            and "display_name" for each device.
         """
-        devices = yield self.db_pool.simple_select_list(
+        devices = await self.db_pool.simple_select_list(
             table="devices",
             keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
         return {d["device_id"]: d for d in devices}
 
     @trace
-    @defer.inlineCallbacks
-    def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+    async def get_device_updates_by_remote(
+        self, destination: str, from_stream_id: int, limit: int
+    ) -> Tuple[int, List[Tuple[str, dict]]]:
         """Get a stream of device updates to send to the given remote server.
 
         Args:
-            destination (str): The host the device updates are intended for
-            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
-            limit (int): Maximum number of device updates to return
+            destination: The host the device updates are intended for
+            from_stream_id: The minimum stream_id to filter updates by, exclusive
+            limit: Maximum number of device updates to return
+
         Returns:
-            Deferred[tuple[int, list[tuple[string,dict]]]]:
-                current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates, where each update is a pair of EDU
-                type and EDU contents
+            A mapping from the  current stream id (ie, the stream id of the last
+            update included in the response), and the list of updates, where
+            each update is a pair of EDU type and EDU contents.
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []
 
-        updates = yield self.db_pool.runInteraction(
+        updates = await self.db_pool.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
@@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
         master_key_by_user = {}
         self_signing_key_by_user = {}
         for user in users:
-            cross_signing_key = yield defer.ensureDeferred(
-                self.get_e2e_cross_signing_key(user, "master")
-            )
+            cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
                     cross_signing_key
@@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
 
-            cross_signing_key = yield defer.ensureDeferred(
-                self.get_e2e_cross_signing_key(user, "self_signing")
+            cross_signing_key = await self.get_e2e_cross_signing_key(
+                user, "self_signing"
             )
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
@@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 if update_stream_id > previous_update_stream_id:
                     query_map[key] = (update_stream_id, update_context)
 
-        results = yield self._get_device_update_edus_by_remote(
+        results = await self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
@@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
         return now_stream_id, results
 
     def _get_device_updates_by_remote_txn(
-        self, txn, destination, from_stream_id, now_stream_id, limit
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        from_stream_id: int,
+        now_stream_id: int,
+        limit: int,
     ):
         """Return device update information for a given remote destination
 
         Args:
-            txn (LoggingTransaction): The transaction to execute
-            destination (str): The host the device updates are intended for
-            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
-            now_stream_id (int): The maximum stream_id to filter updates by, inclusive
-            limit (int): Maximum number of device updates to return
+            txn: The transaction to execute
+            destination: The host the device updates are intended for
+            from_stream_id: The minimum stream_id to filter updates by, exclusive
+            now_stream_id: The maximum stream_id to filter updates by, inclusive
+            limit: Maximum number of device updates to return
 
         Returns:
             List: List of device updates
@@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return list(txn)
 
-    @defer.inlineCallbacks
-    def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
+    async def _get_device_update_edus_by_remote(
+        self,
+        destination: str,
+        from_stream_id: int,
+        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
+    ) -> List[Tuple[str, dict]]:
         """Returns a list of device update EDUs as well as E2EE keys
 
         Args:
-            destination (str): The host the device updates are intended for
-            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            destination: The host the device updates are intended for
+            from_stream_id: The minimum stream_id to filter updates by, exclusive
             query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
                 user_id/device_id to update stream_id and the relevant json-encoded
                 opentracing context
 
         Returns:
-            List[Dict]: List of objects representing an device update EDU
-
+            List of objects representing an device update EDU
         """
         devices = (
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_e2e_device_keys_txn",
                 self._get_e2e_device_keys_txn,
                 query_map.keys(),
@@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
         for user_id, user_devices in devices.items():
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
-            prev_id = yield self._get_last_device_update_for_remote_user(
+            prev_id = await self._get_last_device_update_for_remote_user(
                 destination, user_id, from_stream_id
             )
 
@@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
     def _get_last_device_update_for_remote_user(
-        self, destination, user_id, from_stream_id
+        self, destination: str, user_id: str, from_stream_id: int
     ):
         def f(txn):
             prev_sent_id_sql = """
@@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
 
-    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+    def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
         """Mark that updates have successfully been sent to the destination.
         """
         return self.db_pool.runInteraction(
@@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
             stream_id,
         )
 
-    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+    def _mark_as_sent_devices_by_remote_txn(
+        self, txn: LoggingTransaction, destination: str, stream_id: int
+    ) -> None:
         # We update the device_lists_outbound_last_success with the successfully
         # poked users.
         sql = """
@@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, stream_id))
 
-    @defer.inlineCallbacks
-    def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+    async def add_user_signature_change_to_streams(
+        self, from_user_id: str, user_ids: List[str]
+    ) -> int:
         """Persist that a user has made new signatures
 
         Args:
-            from_user_id (str): the user who made the signatures
-            user_ids (list[str]): the users who were signed
+            from_user_id: the user who made the signatures
+            user_ids: the users who were signed
+
+        Returns:
+            THe new stream ID.
         """
 
         with self._device_list_id_gen.get_next() as stream_id:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
                 from_user_id,
@@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
             )
         return stream_id
 
-    def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+    def _add_user_signature_change_txn(
+        self,
+        txn: LoggingTransaction,
+        from_user_id: str,
+        user_ids: List[str],
+        stream_id: int,
+    ) -> None:
         txn.call_after(
             self._user_signature_stream_cache.entity_has_changed,
             from_user_id,
@@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
             },
         )
 
-    def get_device_stream_token(self):
+    def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
     @trace
-    @defer.inlineCallbacks
-    def get_user_devices_from_cache(self, query_list):
+    async def get_user_devices_from_cache(
+        self, query_list: List[Tuple[str, str]]
+    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
-            query_list(list): List of (user_id, device_ids), if device_ids is
+            query_list: List of (user_id, device_ids), if device_ids is
                 falsey then return all device ids for that user.
 
         Returns:
-            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
-            a set of user_ids and results_map is a mapping of
-            user_id -> device_id -> device_info
+            A tuple of (user_ids_not_in_cache, results_map), where
+            user_ids_not_in_cache is a set of user_ids and results_map is a
+            mapping of user_id -> device_id -> device_info.
         """
         user_ids = {user_id for user_id, _ in query_list}
-        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
 
         # We go and check if any of the users need to have their device lists
         # resynced. If they do then we remove them from the cached list.
-        users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+        users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
             user_ids
         )
         user_ids_in_cache = {
@@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
                 continue
 
             if device_id:
-                device = yield self._get_cached_user_device(user_id, device_id)
+                device = await self._get_cached_user_device(user_id, device_id)
                 results.setdefault(user_id, {})[device_id] = device
             else:
-                results[user_id] = yield self.get_cached_devices_for_user(user_id)
+                results[user_id] = await self.get_cached_devices_for_user(user_id)
 
         set_tag("in_cache", results)
         set_tag("not_in_cache", user_ids_not_in_cache)
 
         return user_ids_not_in_cache, results
 
-    @cachedInlineCallbacks(num_args=2, tree=True)
-    def _get_cached_user_device(self, user_id, device_id):
-        content = yield self.db_pool.simple_select_one_onecol(
+    @cached(num_args=2, tree=True)
+    async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+        content = await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id, "device_id": device_id},
             retcol="content",
@@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
         )
         return db_to_json(content)
 
-    @cachedInlineCallbacks()
-    def get_cached_devices_for_user(self, user_id):
-        devices = yield self.db_pool.simple_select_list(
+    @cached()
+    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+        devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
             retcols=("device_id", "content"),
@@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
-    def get_devices_with_keys_by_user(self, user_id):
+    def get_devices_with_keys_by_user(self, user_id: str):
         """Get all devices (with any device keys) for a user
 
         Returns:
-            (stream_id, devices)
+            Deferred which resolves to (stream_id, devices)
         """
         return self.db_pool.runInteraction(
             "get_devices_with_keys_by_user",
@@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id,
         )
 
-    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+    def _get_devices_with_keys_by_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         now_stream_id = self._device_list_id_gen.get_current_token()
 
         devices = self._get_e2e_device_keys_txn(
@@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return now_stream_id, []
 
-    def get_users_whose_devices_changed(self, from_key, user_ids):
+    async def get_users_whose_devices_changed(
+        self, from_key: str, user_ids: Iterable[str]
+    ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
 
         Args:
-            from_key (str): The device lists stream token
-            user_ids (Iterable[str])
+            from_key: The device lists stream token
+            user_ids: The user IDs to query for devices.
 
         Returns:
-            Deferred[set[str]]: The set of user_ids whose devices have changed
-            since `from_key`
+            The set of user_ids whose devices have changed since `from_key`
         """
         from_key = int(from_key)
 
@@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
         if not to_check:
-            return defer.succeed(set())
+            return set()
 
         def _get_users_whose_devices_changed_txn(txn):
             changes = set()
@@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
 
             return changes
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
-    @defer.inlineCallbacks
-    def get_users_whose_signatures_changed(self, user_id, from_key):
+    async def get_users_whose_signatures_changed(
+        self, user_id: str, from_key: str
+    ) -> Set[str]:
         """Get the users who have new cross-signing signatures made by `user_id` since
         `from_key`.
 
         Args:
-            user_id (str): the user who made the signatures
-            from_key (str): The device lists stream token
+            user_id: the user who made the signatures
+            from_key: The device lists stream token
+
+        Returns:
+            A set of user IDs with updated signatures.
         """
         from_key = int(from_key)
         if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
@@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 SELECT DISTINCT user_ids FROM user_signature_stream
                 WHERE from_user_id = ? AND stream_id > ?
             """
-            rows = yield self.db_pool.execute(
+            rows = await self.db_pool.execute(
                 "get_users_whose_signatures_changed", None, sql, user_id, from_key
             )
             return {user for row in rows for user in db_to_json(row[0])}
@@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id):
+    def get_device_list_last_stream_id_for_remote(self, user_id: str):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
@@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
         list_name="user_ids",
         inlineCallbacks=True,
     )
-    def get_device_list_last_stream_id_for_remotes(self, user_ids):
+    def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
         rows = yield self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
@@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return results
 
-    @defer.inlineCallbacks
-    def get_user_ids_requiring_device_list_resync(
+    async def get_user_ids_requiring_device_list_resync(
         self, user_ids: Optional[Collection[str]] = None,
     ) -> Set[str]:
         """Given a list of remote users return the list of users that we
@@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
             The IDs of users whose device lists need resync.
         """
         if user_ids:
-            rows = yield self.db_pool.simple_select_many_batch(
+            rows = await self.db_pool.simple_select_many_batch(
                 table="device_lists_remote_resync",
                 column="user_id",
                 iterable=user_ids,
@@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 desc="get_user_ids_requiring_device_list_resync_with_iterable",
             )
         else:
-            rows = yield self.db_pool.simple_select_list(
+            rows = await self.db_pool.simple_select_list(
                 table="device_lists_remote_resync",
                 keyvalues=None,
                 retcols=("user_id",),
@@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
         """Mark that we no longer track device lists for remote user.
         """
 
@@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             "drop_device_lists_outbound_last_success_non_unique_idx",
         )
 
-    @defer.inlineCallbacks
-    def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+    async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
             txn.close()
 
-        yield self.db_pool.runWithConnection(f)
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.runWithConnection(f)
+        await self.db_pool.updates._end_background_update(
             DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
         )
         return 1
@@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
 
-    @defer.inlineCallbacks
-    def store_device(self, user_id, device_id, initial_device_display_name):
+    async def store_device(
+        self, user_id: str, device_id: str, initial_device_display_name: str
+    ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
         Args:
-            user_id (str): id of user associated with the device
-            device_id (str): id of device
-            initial_device_display_name (str): initial displayname of the
-               device. Ignored if device exists.
+            user_id: id of user associated with the device
+            device_id: id of device
+            initial_device_display_name: initial displayname of the device.
+                Ignored if device exists.
+
         Returns:
-            defer.Deferred: boolean whether the device was inserted or an
-                existing device existed with that ID.
+            Whether the device was inserted or an existing device existed with that ID.
+
         Raises:
             StoreError: if the device is already in use
         """
@@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return False
 
         try:
-            inserted = yield self.db_pool.simple_insert(
+            inserted = await self.db_pool.simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not inserted:
                 # if the device already exists, check if it's a real device, or
                 # if the device ID is reserved by something else
-                hidden = yield self.db_pool.simple_select_one_onecol(
+                hidden = await self.db_pool.simple_select_one_onecol(
                     "devices",
                     keyvalues={"user_id": user_id, "device_id": device_id},
                     retcol="hidden",
@@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
             raise StoreError(500, "Problem storing device.")
 
-    @defer.inlineCallbacks
-    def delete_device(self, user_id, device_id):
+    async def delete_device(self, user_id: str, device_id: str) -> None:
         """Delete a device.
 
         Args:
-            user_id (str): The ID of the user which owns the device
-            device_id (str): The ID of the device to delete
-        Returns:
-            defer.Deferred
+            user_id: The ID of the user which owns the device
+            device_id: The ID of the device to delete
         """
-        yield self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
@@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         self.device_id_exists_cache.invalidate((user_id, device_id))
 
-    @defer.inlineCallbacks
-    def delete_devices(self, user_id, device_ids):
+    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
         """Deletes several devices.
 
         Args:
-            user_id (str): The ID of the user which owns the devices
-            device_ids (list): The IDs of the devices to delete
-        Returns:
-            defer.Deferred
+            user_id: The ID of the user which owns the devices
+            device_ids: The IDs of the devices to delete
         """
-        yield self.db_pool.simple_delete_many(
+        await self.db_pool.simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
@@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
-    def update_device(self, user_id, device_id, new_display_name=None):
+    async def update_device(
+        self, user_id: str, device_id: str, new_display_name: Optional[str] = None
+    ) -> None:
         """Update a device. Only updates the device if it is not marked as
         hidden.
 
         Args:
-            user_id (str): The ID of the user which owns the device
-            device_id (str): The ID of the device to update
-            new_display_name (str|None): new displayname for device; None
-               to leave unchanged
+            user_id: The ID of the user which owns the device
+            device_id: The ID of the device to update
+            new_display_name: new displayname for device; None to leave unchanged
         Raises:
             StoreError: if the device is not found
-        Returns:
-            defer.Deferred
         """
         updates = {}
         if new_display_name is not None:
             updates["display_name"] = new_display_name
         if not updates:
-            return defer.succeed(None)
-        return self.db_pool.simple_update_one(
+            return None
+        await self.db_pool.simple_update_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
@@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     def update_remote_device_list_cache_entry(
-        self, user_id, device_id, content, stream_id
+        self, user_id: str, device_id: str, content: JsonDict, stream_id: int
     ):
         """Updates a single device in the cache of a remote user's devicelist.
 
@@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         device list.
 
         Args:
-            user_id (str): User to update device list for
-            device_id (str): ID of decivice being updated
-            content (dict): new data on this device
-            stream_id (int): the version of the device list
+            user_id: User to update device list for
+            device_id: ID of decivice being updated
+            content: new data on this device
+            stream_id: the version of the device list
 
         Returns:
             Deferred[None]
@@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     def _update_remote_device_list_cache_entry_txn(
-        self, txn, user_id, device_id, content, stream_id
-    ):
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        content: JsonDict,
+        stream_id: int,
+    ) -> None:
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             lock=False,
         )
 
-    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+    def update_remote_device_list_cache(
+        self, user_id: str, devices: List[dict], stream_id: int
+    ):
         """Replace the entire cache of the remote user's devices.
 
         Note: assumes that we are the only thread that can be updating this user's
         device list.
 
         Args:
-            user_id (str): User to update device list for
-            devices (list[dict]): list of device objects supplied over federation
-            stream_id (int): the version of the device list
+            user_id: User to update device list for
+            devices: list of device objects supplied over federation
+            stream_id: the version of the device list
 
         Returns:
             Deferred[None]
@@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             stream_id,
         )
 
-    def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
+    def _update_remote_device_list_cache_txn(
+        self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
+    ):
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
@@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
         )
 
-    @defer.inlineCallbacks
-    def add_device_change_to_streams(self, user_id, device_ids, hosts):
+    async def add_device_change_to_streams(
+        self, user_id: str, device_ids: Collection[str], hosts: List[str]
+    ):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
@@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return
 
         with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 self._add_device_change_to_stream_txn,
                 user_id,
@@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_device_outbound_poke_to_stream",
                 self._add_device_outbound_poke_to_stream_txn,
                 user_id,
@@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     def _add_device_outbound_poke_to_stream_txn(
-        self, txn, user_id, device_ids, hosts, stream_ids, context,
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: List[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
     ):
         for host in hosts:
             txn.call_after(
@@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             ],
         )
 
-    def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
+    def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
         that we don't fill up due to dead servers.
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64ddd8243d..64afd581bc 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+        self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
             (0, [])
         )
 
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..87ed8f8cd1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name")
+        )
 
         res = yield self.store.get_device("user_id", "device_id")
         self.assertDictContainsSubset(
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self.store.store_device("user_id", "device1", "display_name 1")
-        yield self.store.store_device("user_id", "device2", "display_name 2")
-        yield self.store.store_device("user_id2", "device3", "display_name 3")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device1", "display_name 1")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device2", "display_name 2")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id2", "device3", "display_name 3")
+        )
 
-        res = yield self.store.get_devices_by_user("user_id")
+        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids, ["somehost"]
+        yield defer.ensureDeferred(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "somehost", -1, limit=100
+        now_stream_id, device_updates = yield defer.ensureDeferred(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
         # Check original device_ids are contained within these updates
@@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_update_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name 1")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name 1")
+        )
 
         res = yield self.store.get_device("user_id", "device_id")
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield self.store.update_device("user_id", "device_id")
+        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
         res = yield self.store.get_device("user_id", "device_id")
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield self.store.update_device(
-            "user_id", "device_id", new_display_name="display_name 2"
+        yield defer.ensureDeferred(
+            self.store.update_device(
+                "user_id", "device_id", new_display_name="display_name 2"
+            )
         )
 
         # check it worked
@@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_update_unknown_device(self):
         with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield self.store.update_device(
-                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            yield defer.ensureDeferred(
+                self.store.update_device(
+                    "user_id", "unknown_device_id", new_display_name="display_name 2"
+                )
             )
         self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 9f8d30373b..d57cdffd8b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
         yield self.store.set_e2e_device_keys("user", "device", now, json)
 
@@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
         changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
         self.assertTrue(changed)
@@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         json = {"key": "value"}
 
         yield self.store.set_e2e_device_keys("user", "device", now, json)
-        yield self.store.store_device("user", "device", "display_name")
+        yield defer.ensureDeferred(
+            self.store.store_device("user", "device", "display_name")
+        )
 
         res = yield defer.ensureDeferred(
             self.store.get_e2e_device_keys((("user", "device"),))
@@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield self.store.store_device("user1", "device1", None)
-        yield self.store.store_device("user1", "device2", None)
-        yield self.store.store_device("user2", "device1", None)
-        yield self.store.store_device("user2", "device2", None)
+        yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
 
         yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
         yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})