summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8042.misc1
-rw-r--r--synapse/storage/databases/main/devices.py12
-rw-r--r--synapse/storage/databases/main/directory.py51
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py30
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py73
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py31
-rw-r--r--tests/handlers/test_appservice.py2
-rw-r--r--tests/storage/test_directory.py32
-rw-r--r--tests/storage/test_end_to_end_keys.py12
-rw-r--r--tests/storage/test_monthly_active_users.py17
10 files changed, 141 insertions, 120 deletions
diff --git a/changelog.d/8042.misc b/changelog.d/8042.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8042.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 81e64de126..7a5f0bab05 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
         master_key_by_user = {}
         self_signing_key_by_user = {}
         for user in users:
-            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            cross_signing_key = yield defer.ensureDeferred(
+                self.get_e2e_cross_signing_key(user, "master")
+            )
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
                     cross_signing_key
@@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
 
-            cross_signing_key = yield self.get_e2e_cross_signing_key(
-                user, "self_signing"
+            cross_signing_key = yield defer.ensureDeferred(
+                self.get_e2e_cross_signing_key(user, "self_signing")
             )
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
@@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
             destination (str): The host the device updates are intended for
             from_stream_id (int): The minimum stream_id to filter updates by, exclusive
             query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
-                user_id/device_id to update stream_id and the relevent json-encoded
+                user_id/device_id to update stream_id and the relevant json-encoded
                 opentracing context
 
         Returns:
@@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
             between the requested tokens due to the limit.
 
             The token returned can be used in a subsequent call to this
-            function to get further updatees.
+            function to get further updates.
 
             The updates are a list of 2-tuples of stream ID and the row data
         """
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 7819bfcbb3..037e02603c 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,30 +14,29 @@
 # limitations under the License.
 
 from collections import namedtuple
-from typing import Optional
-
-from twisted.internet import defer
+from typing import Iterable, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
 from synapse.util.caches.descriptors import cached
 
 RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
 
 
 class DirectoryWorkerStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_association_from_room_alias(self, room_alias):
-        """ Get's the room_id and server list for a given room_alias
+    async def get_association_from_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
+        """Gets the room_id and server list for a given room_alias
 
         Args:
-            room_alias (RoomAlias)
+            room_alias: The alias to translate to an ID.
 
         Returns:
-            Deferred: results in namedtuple with keys "room_id" and
-            "servers" or None if no association can be found
+            The room alias mapping or None if no association can be found.
         """
-        room_id = yield self.db_pool.simple_select_one_onecol(
+        room_id = await self.db_pool.simple_select_one_onecol(
             "room_aliases",
             {"room_alias": room_alias.to_string()},
             "room_id",
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         if not room_id:
             return None
 
-        servers = yield self.db_pool.simple_select_onecol(
+        servers = await self.db_pool.simple_select_onecol(
             "room_alias_servers",
             {"room_alias": room_alias.to_string()},
             "server",
@@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
 
 
 class DirectoryStore(DirectoryWorkerStore):
-    @defer.inlineCallbacks
-    def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+    async def create_room_alias_association(
+        self,
+        room_alias: RoomAlias,
+        room_id: str,
+        servers: Iterable[str],
+        creator: Optional[str] = None,
+    ) -> None:
         """ Creates an association between a room alias and room_id/servers
 
         Args:
-            room_alias (RoomAlias)
-            room_id (str)
-            servers (list)
-            creator (str): Optional user_id of creator.
-
-        Returns:
-            Deferred
+            room_alias: The alias to create.
+            room_id: The target of the alias.
+            servers: A list of servers through which it may be possible to join the room
+            creator: Optional user_id of creator.
         """
 
         def alias_txn(txn):
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
             )
 
         try:
-            ret = yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "create_room_alias_association", alias_txn
             )
         except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
             )
-        return ret
 
-    @defer.inlineCallbacks
-    def delete_room_alias(self, room_alias):
-        room_id = yield self.db_pool.runInteraction(
+    async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+        room_id = await self.db_pool.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
 
         return room_id
 
-    def _delete_room_alias_txn(self, txn, room_alias):
+    def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
         txn.execute(
             "SELECT room_id FROM room_aliases WHERE room_alias = ?",
             (room_alias.to_string(),),
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index c4aaec3993..2eeb9f97dc 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -23,8 +21,9 @@ from synapse.util import json_encoder
 
 
 class EndToEndRoomKeyStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+    async def update_e2e_room_key(
+        self, user_id, version, room_id, session_id, room_key
+    ):
         """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
@@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             StoreError
         """
 
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
@@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="update_e2e_room_key",
         )
 
-    @defer.inlineCallbacks
-    def add_e2e_room_keys(self, user_id, version, room_keys):
+    async def add_e2e_room_keys(self, user_id, version, room_keys):
         """Bulk add room keys to a given backup.
 
         Args:
@@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 }
             )
 
-        yield self.db_pool.simple_insert_many(
+        await self.db_pool.simple_insert_many(
             table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
         )
 
     @trace
-    @defer.inlineCallbacks
-    def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
 
@@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 the backup (or for the specified room)
 
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """
 
@@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        rows = yield self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             table="e2e_room_keys",
             keyvalues=keyvalues,
             retcols=(
@@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    @defer.inlineCallbacks
-    def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_e2e_room_keys(
+        self, user_id, version, room_id=None, session_id=None
+    ):
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
 
@@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 the backup (or for the specified room)
 
         Returns:
-            A deferred of the deletion transaction
+            The deletion transaction
         """
 
         keyvalues = {"user_id": user_id, "version": int(version)}
@@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        yield self.db_pool.simple_delete(
+        await self.db_pool.simple_delete(
             table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
         )
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 6126376a6f..f93e0d320d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,12 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
 from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
-    @defer.inlineCallbacks
-    def get_e2e_device_keys(
+    async def get_e2e_device_keys(
         self, query_list, include_all_devices=False, include_deleted_devices=False
     ):
         """Fetch a list of device keys.
@@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = yield self.db_pool.runInteraction(
+        results = await self.db_pool.runInteraction(
             "get_e2e_device_keys",
             self._get_e2e_device_keys_txn,
             query_list,
@@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         log_kv(result)
         return result
 
-    @defer.inlineCallbacks
-    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+    async def get_e2e_one_time_keys(
+        self, user_id: str, device_id: str, key_ids: List[str]
+    ) -> Dict[Tuple[str, str], str]:
         """Retrieve a number of one-time keys for a user
 
         Args:
@@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 retrieve
 
         Returns:
-            deferred resolving to Dict[(str, str), str]: map from (algorithm,
-            key_id) to json string for key
+            A map from (algorithm, key_id) to json string for key
         """
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
             iterable=key_ids,
@@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
         return result
 
-    @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+    async def add_e2e_one_time_keys(
+        self,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        new_keys: Iterable[Tuple[str, str, str]],
+    ) -> None:
         """Insert some new one time keys for a device. Errors if any of the
         keys already exist.
 
         Args:
-            user_id(str): id of user to get keys for
-            device_id(str): id of device to get keys for
-            time_now(long): insertion time to record (ms since epoch)
-            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
-                (algorithm, key_id, key json)
+            user_id: id of user to get keys for
+            device_id: id of device to get keys for
+            time_now: insertion time to record (ms since epoch)
+            new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
         def _add_e2e_one_time_keys(txn):
@@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
@@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
-    @defer.inlineCallbacks
-    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+    async def get_e2e_cross_signing_key(
+        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+    ) -> Optional[dict]:
         """Returns a user's cross-signing key.
 
         Args:
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being requested: either 'master'
+            user_id: the user whose key is being requested
+            key_type: the type of key that is being requested: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
+            from_user_id: if specified, signatures made by this user on
                 the self-signing key will be included in the result
 
         Returns:
             dict of the key data or None if not found
         """
-        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
         user_keys = res.get(user_id)
         if not user_keys:
             return None
@@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return keys
 
-    @defer.inlineCallbacks
-    def get_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str], from_user_id: str = None
-    ) -> defer.Deferred:
+    async def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: Optional[str] = None
+    ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
-            user_ids (list[str]): the users whose keys are being requested
-            from_user_id (str): if specified, signatures made by this user on
+            user_ids: the users whose keys are being requested
+            from_user_id: if specified, signatures made by this user on
                 the self-signing keys will be included in the result
 
         Returns:
-            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
-                key data.  If a user's cross-signing keys were not found, either
-                their user ID will not be in the dict, or their user ID will map
-                to None.
+            A map of user ID to key type to key data.  If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict,
+            or their user ID will map to None.
         """
 
-        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+        result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
-            result = yield self.db_pool.runInteraction(
+            result = await self.db_pool.runInteraction(
                 "get_e2e_cross_signing_signatures",
                 self._get_e2e_cross_signing_signatures_txn,
                 result,
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 02b01d9619..e71cdd2cb4 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,8 +15,6 @@
 import logging
 from typing import List
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
@@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
-    @defer.inlineCallbacks
-    def upsert_monthly_active_user(self, user_id):
+    async def upsert_monthly_active_user(self, user_id: str) -> None:
         """Updates or inserts the user into the monthly active user table, which
         is used to track the current MAU usage of the server
 
         Args:
-            user_id (str): user to add/update
-
-        Returns:
-            Deferred
+            user_id: user to add/update
         """
         # Support user never to be included in MAU stats. Note I can't easily call this
         # from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         # _initialise_reserved_users reasoning that it would be very strange to
         #  include a support user in this context.
 
-        is_support = yield self.is_support_user(user_id)
+        is_support = await self.is_support_user(user_id)
         if is_support:
             return
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
@@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
 
         return is_insert
 
-    @defer.inlineCallbacks
-    def populate_monthly_active_users(self, user_id):
+    async def populate_monthly_active_users(self, user_id):
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables
 
@@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         """
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = yield self.is_guest(user_id)
+            is_guest = await self.is_guest(user_id)
             if is_guest:
                 return
-            is_trial = yield self.is_trial_user(user_id)
+            is_trial = await self.is_trial_user(user_id)
             if is_trial:
                 return
 
-            last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+            last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
             now = self.hs.get_clock().time_msec()
 
             # We want to reduce to the total number of db writes, and are happy
@@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
                 # False, there is no point in checking get_monthly_active_count - it
                 # adds no value and will break the logic if max_mau_value is exceeded.
                 if not self._limit_usage_by_mau:
-                    yield self.upsert_monthly_active_user(user_id)
+                    await self.upsert_monthly_active_user(user_id)
                 else:
-                    count = yield self.get_monthly_active_count()
+                    count = await self.get_monthly_active_count()
                     if count < self._max_mau_value:
-                        yield self.upsert_monthly_active_user(user_id)
+                        await self.upsert_monthly_active_user(user_id)
             elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
-                yield self.upsert_monthly_active_user(user_id)
+                await self.upsert_monthly_active_user(user_id)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 628f7d8db0..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
         self.mock_as_api.query_alias.return_value = make_awaitable(True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
             Mock(room_id=room_id, servers=servers)
         )
 
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertEquals(
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (yield self.store.get_association_from_room_alias(self.alias)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
-        room_id = yield self.store.delete_room_alias(self.alias)
+        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (yield self.store.get_association_from_room_alias(self.alias))
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            )
         )
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..9f8d30373b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
 
         yield self.store.set_e2e_device_keys("user", "device", now, json)
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.set_e2e_device_keys("user", "device", now, json)
         yield self.store.store_device("user", "device", "display_name")
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
         yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
 
-        res = yield self.store.get_e2e_device_keys(
-            (("user1", "device1"), ("user2", "device2"))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
         )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 259f2215f1..e793781a26 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
 
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(
@@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))