summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-27 07:08:38 -0400
committerGitHub <noreply@github.com>2020-08-27 07:08:38 -0400
commit4a739c73b404284253a548f60197e70c6c385645 (patch)
tree6521f9d972250b5bf50748b691a54c0a7e9d0282
parentReduce run-times of tests by advancing the reactor less (#7757) (diff)
downloadsynapse-4a739c73b404284253a548f60197e70c6c385645.tar.xz
Convert simple_update* and simple_select* to async (#8173)
Diffstat (limited to '')
-rw-r--r--changelog.d/8173.misc1
-rw-r--r--synapse/handlers/room.py6
-rw-r--r--synapse/storage/database.py29
-rw-r--r--synapse/storage/databases/main/__init__.py8
-rw-r--r--synapse/storage/databases/main/directory.py6
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py26
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/storage/databases/main/group_server.py55
-rw-r--r--synapse/storage/databases/main/media_repository.py16
-rw-r--r--synapse/storage/databases/main/profile.py18
-rw-r--r--synapse/storage/databases/main/receipts.py8
-rw-r--r--synapse/storage/databases/main/registration.py22
-rw-r--r--synapse/storage/databases/main/room.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py4
-rw-r--r--tests/handlers/test_stats.py4
-rw-r--r--tests/storage/test_base.py26
-rw-r--r--tests/storage/test_directory.py6
-rw-r--r--tests/storage/test_main.py4
-rw-r--r--tests/test_federation.py50
19 files changed, 164 insertions, 133 deletions
diff --git a/changelog.d/8173.misc b/changelog.d/8173.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8173.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e4788ef86b..236a37f777 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,7 +51,7 @@ from synapse.types import (
     create_requester,
 )
 from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer, maybe_awaitable
+from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
@@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object):
                 ratelimit=False,
             )
 
-            aliases_for_room = await maybe_awaitable(
-                self.store.get_aliases_for_room(room_id)
-            )
+            aliases_for_room = await self.store.get_aliases_for_room(room_id)
 
             await self.store.update_aliases_for_room(
                 room_id, new_room_id, requester_user_id
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 181c3ec249..38010af600 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1132,13 +1132,13 @@ class DatabasePool(object):
 
         return [r[0] for r in txn]
 
-    def simple_select_onecol(
+    async def simple_select_onecol(
         self,
         table: str,
         keyvalues: Optional[Dict[str, Any]],
         retcol: str,
         desc: str = "simple_select_onecol",
-    ) -> defer.Deferred:
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
@@ -1148,19 +1148,19 @@ class DatabasePool(object):
             retcol: column whos value we wish to retrieve.
 
         Returns:
-            Deferred: Results in a list
+            Results in a list
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(
+    async def simple_select_list(
         self,
         table: str,
         keyvalues: Optional[Dict[str, Any]],
         retcols: Iterable[str],
         desc: str = "simple_select_list",
-    ) -> defer.Deferred:
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -1170,10 +1170,11 @@ class DatabasePool(object):
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
             retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries.
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_list_txn, table, keyvalues, retcols
         )
 
@@ -1299,14 +1300,14 @@ class DatabasePool(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(
+    async def simple_update(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         updatevalues: Dict[str, Any],
         desc: str,
-    ) -> defer.Deferred:
-        return self.runInteraction(
+    ) -> int:
+        return await self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
@@ -1332,13 +1333,13 @@ class DatabasePool(object):
 
         return txn.rowcount
 
-    def simple_update_one(
+    async def simple_update_one(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         updatevalues: Dict[str, Any],
         desc: str = "simple_update_one",
-    ) -> defer.Deferred:
+    ) -> None:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
@@ -1347,7 +1348,7 @@ class DatabasePool(object):
             keyvalues: dict of column names and values to select the row with
             updatevalues: dict giving column names and values to update
         """
-        return self.runInteraction(
+        await self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0934ae276c..8b9b6eb472 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,6 +18,7 @@
 import calendar
 import logging
 import time
+from typing import Any, Dict, List
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -476,14 +477,13 @@ class DataStore(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
-    def get_users(self):
+    async def get_users(self) -> List[Dict[str, Any]]:
         """Function to retrieve a list of users in users table.
 
-        Args:
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries representing users.
         """
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="users",
             keyvalues={},
             retcols=[
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 301d5d845a..405b5eafa5 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from collections import namedtuple
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=5000)
-    def get_aliases_for_room(self, room_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
             "room_alias",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 46c3e33cc6..82f9d870fd 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    def update_e2e_room_keys_version(
-        self, user_id, version, info=None, version_etag=None
-    ):
+    async def update_e2e_room_keys_version(
+        self,
+        user_id: str,
+        version: str,
+        info: Optional[dict] = None,
+        version_etag: Optional[int] = None,
+    ) -> None:
         """Update a given backup version
 
         Args:
-            user_id(str): the user whose backup version we're updating
-            version(str): the version ID of the backup version we're updating
-            info (dict): the new backup version info to store.  If None, then
-                the backup version info is not updated
-            version_etag (Optional[int]): etag of the keys in the backup.  If
-                None, then the etag is not updated
+            user_id: the user whose backup version we're updating
+            version: the version ID of the backup version we're updating
+            info: the new backup version info to store. If None, then the backup
+                version info is not updated.
+            version_etag: etag of the keys in the backup. If None, then the etag
+                is not updated.
         """
         updatevalues = {}
 
@@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            return self.db_pool.simple_update(
+            await self.db_pool.simple_update(
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": version},
                 updatevalues=updatevalues,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index e6a97b018c..6e5761c7b7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    def get_latest_event_ids_in_room(self, room_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index c39864f59f..e3ead71853 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_group",
         )
 
-    def get_users_in_group(self, group_id, include_private=False):
+    async def get_users_in_group(
+        self, group_id: str, include_private: bool = False
+    ) -> List[Dict[str, Any]]:
         # TODO: Pagination
 
         keyvalues = {"group_id": group_id}
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
             retcols=("user_id", "is_public", "is_admin"),
             desc="get_users_in_group",
         )
 
-    def get_invited_users_in_group(self, group_id):
+    async def get_invited_users_in_group(self, group_id: str) -> List[str]:
         # TODO: Pagination
 
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id},
             retcol="user_id",
@@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return role
 
-    def get_local_groups_for_room(self, room_id):
+    async def get_local_groups_for_room(self, room_id: str) -> List[str]:
         """Get all of the local group that contain a given room
         Args:
-            room_id (str): The ID of a room
+            room_id: The ID of a room
         Returns:
-            Deferred[list[str]]: A twisted.Deferred containing a list of group ids
-                containing this room
+            A list of group ids containing this room
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="group_rooms",
             keyvalues={"room_id": room_id},
             retcol="group_id",
@@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_users_membership_info_in_group", _get_users_membership_in_group_txn
         )
 
-    def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
         """Get all groups a user is publicising
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
             retcol="group_id",
@@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return None
 
-    def get_joined_groups(self, user_id):
-        return self.db_pool.simple_select_onecol(
+    async def get_joined_groups(self, user_id: str) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join"},
             retcol="group_id",
@@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore):
 
 
 class GroupServerStore(GroupServerWorkerStore):
-    def set_group_join_policy(self, group_id, join_policy):
+    async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
         """Set the join policy of a group.
 
         join_policy can be one of:
          * "invite"
          * "open"
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues={"join_policy": join_policy},
@@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="add_room_to_group",
         )
 
-    def update_room_in_group_visibility(self, group_id, room_id, is_public):
-        return self.db_pool.simple_update(
+    async def update_room_in_group_visibility(
+        self, group_id: str, room_id: str, is_public: bool
+    ) -> int:
+        return await self.db_pool.simple_update(
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
             updatevalues={"is_public": is_public},
@@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore):
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
-    def update_group_publicity(self, group_id, user_id, publicise):
+    async def update_group_publicity(
+        self, group_id: str, user_id: str, publicise: bool
+    ) -> None:
         """Update whether the user is publicising their membership of the group
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"is_publicised": publicise},
@@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="update_group_profile",
         )
 
-    def update_attestation_renewal(self, group_id, user_id, attestation):
+    async def update_attestation_renewal(
+        self, group_id: str, user_id: str, attestation: dict
+    ) -> None:
         """Update an attestation that we have renewed
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
             desc="update_attestation_renewal",
         )
 
-    def update_remote_attestion(self, group_id, user_id, attestation):
+    async def update_remote_attestion(
+        self, group_id: str, user_id: str, attestation: dict
+    ) -> None:
         """Update an attestation that a remote has renewed
         """
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4ae255ebd8..fc223f5a2a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
-    def mark_local_media_as_safe(self, media_id: str):
+    async def mark_local_media_as_safe(self, media_id: str) -> None:
         """Mark a local media as safe from quarantining."""
-        return self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="local_media_repository",
             keyvalues={"media_id": media_id},
             updatevalues={"safe_from_quarantine": True},
@@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_url_cache",
         )
 
-    def get_local_media_thumbnails(self, media_id):
-        return self.db_pool.simple_select_list(
+    async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             "local_media_repository_thumbnails",
             {"media_id": media_id},
             (
@@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "update_cached_last_access_time", update_cache_txn
         )
 
-    def get_remote_media_thumbnails(self, origin, media_id):
-        return self.db_pool.simple_select_list(
+    async def get_remote_media_thumbnails(
+        self, origin: str, media_id: str
+    ) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             "remote_media_cache_thumbnails",
             {"media_origin": origin, "media_id": media_id},
             (
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8233c4848..858fd92420 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore):
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self.db_pool.simple_update_one(
+    async def set_profile_displayname(
+        self, user_localpart: str, new_displayname: str
+    ) -> None:
+        await self.db_pool.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"displayname": new_displayname},
             desc="set_profile_displayname",
         )
 
-    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self.db_pool.simple_update_one(
+    async def set_profile_avatar_url(
+        self, user_localpart: str, new_avatar_url: str
+    ) -> None:
+        await self.db_pool.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"avatar_url": new_avatar_url},
@@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore):
             desc="add_remote_profile_cache",
         )
 
-    def update_remote_profile_cache(self, user_id, displayname, avatar_url):
-        return self.db_pool.simple_update(
+    async def update_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> int:
+        return await self.db_pool.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             updatevalues={
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index cea5ac9a68..436f22ad2d 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
-    def get_receipts_for_room(self, room_id, receipt_type):
-        return self.db_pool.simple_select_list(
+    async def get_receipts_for_room(
+        self, room_id: str, receipt_type: str
+    ) -> List[Dict[str, Any]]:
+        return await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index eced53d470..48bda66f3e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="add_user_bound_threepid",
         )
 
-    def user_get_bound_threepids(self, user_id):
+    async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
         """Get the threepids that a user has bound to an identity server through the homeserver
         The homeserver remembers where binds to an identity server occurred. Using this
         method can retrieve those threepids.
 
         Args:
-            user_id (str): The ID of the user to retrieve threepids for
+            user_id: The ID of the user to retrieve threepids for
 
         Returns:
-            Deferred[list[dict]]: List of dictionaries containing the following:
+            List of dictionaries containing the following keys:
                 medium (str): The medium of the threepid (e.g "email")
                 address (str): The address of the threepid (e.g "bob@example.com")
         """
-        return self.db_pool.simple_select_list(
+        return await self.db_pool.simple_select_list(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id},
             retcols=["medium", "address"],
@@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="remove_user_bound_threepid",
         )
 
-    def get_id_servers_user_bound(self, user_id, medium, address):
+    async def get_id_servers_user_bound(
+        self, user_id: str, medium: str, address: str
+    ) -> List[str]:
         """Get the list of identity servers that the server proxied bind
         requests to for given user and threepid
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
+            user_id: The user to query for identity servers.
+            medium: The medium to query for identity servers.
+            address: The address to query for identity servers.
 
         Returns:
-            Deferred[list[str]]: Resolves to a list of identity servers
+            A list of identity servers
         """
-        return self.db_pool.simple_select_onecol(
+        return await self.db_pool.simple_select_onecol(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 97ecdb16e4..66d7135413 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore):
             "get_room_with_stats", get_room_with_stats_txn, room_id
         )
 
-    def get_public_room_ids(self):
-        return self.db_pool.simple_select_onecol(
+    async def get_public_room_ids(self) -> List[str]:
+        return await self.db_pool.simple_select_onecol(
             table="rooms",
             keyvalues={"is_public": True},
             retcol="room_id",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 20cbcd851c..a9f2e93614 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    def update_user_directory_stream_pos(self, stream_id):
-        return self.db_pool.simple_update_one(
+    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+        await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
             updatevalues={"stream_id": stream_id},
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 88b05c23a0..a609f148c0 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-    def get_all_room_state(self):
-        return self.store.db_pool.simple_select_list(
+    async def get_all_room_state(self):
+        return await self.store.db_pool.simple_select_list(
             "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
         )
 
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index bf22540d99..64abe8cc49 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -148,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.db_pool.simple_select_list(
-            table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_list(
+                table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+            )
         )
 
         self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -161,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db_pool.simple_update_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            updatevalues={"columnname": "New Value"},
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                updatevalues={"columnname": "New Value"},
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -176,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db_pool.simple_update_one(
-            table="tablename",
-            keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
-            updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+                updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index daac947cb2..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             ["#my-room:test"],
-            (yield self.store.get_aliases_for_room(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_aliases_for_room(self.room.to_string())
+                )
+            ),
         )
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index fbf8af940a..954338a592 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
     def test_get_users_paginate(self):
         yield self.store.register_user(self.user.to_string(), "pass")
         yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
-        yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        )
 
         users, total = yield self.store.get_users_paginate(
             0, 10, name="bc", guests=False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 4a4548433f..27a7fc9ed7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -15,8 +15,9 @@
 
 from mock import Mock
 
-from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
+from twisted.internet.defer import succeed
 
+from synapse.api.errors import FederationError
 from synapse.events import make_event_from_dict
 from synapse.logging.context import LoggingContext
 from synapse.types import Requester, UserID
@@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         user_id = UserID("us", "test")
         our_user = Requester(user_id, None, False, False, None, None)
         room_creator = self.homeserver.get_room_creation_handler()
-        room_deferred = ensureDeferred(
+        self.room_id = self.get_success(
             room_creator.create_room(
                 our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
-        )
-        self.reactor.advance(0.1)
-        self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+        )[0]["room_id"]
 
         self.store = self.homeserver.get_datastore()
 
         # Figure out what the most recent event is
-        most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+        most_recent = self.get_success(
+            self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         join_event = make_event_from_dict(
@@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = ensureDeferred(
-            self.handler.on_receive_pdu(
-                "test.serv", join_event, sent_to_us_directly=True
-            )
+        self.assertEqual(
+            self.get_success(
+                self.handler.on_receive_pdu(
+                    "test.serv", join_event, sent_to_us_directly=True
+                )
+            ),
+            None,
         )
-        self.reactor.advance(1)
-        self.assertEqual(self.successResultOf(d), None)
 
         # Make sure we actually joined the room
         self.assertEqual(
-            self.successResultOf(
-                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
-            )[0],
+            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
             "$join:test.serv",
         )
 
@@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.http_client.post_json = post_json
 
         # Figure out what the most recent event is
-        most_recent = self.successResultOf(
-            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
+        most_recent = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         # Now lie about an event
@@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = ensureDeferred(
+            failure = self.get_failure(
                 self.handler.on_receive_pdu(
                     "test.serv", lying_event, sent_to_us_directly=True
-                )
+                ),
+                FederationError,
             )
 
-            # Step the reactor, so the database fetches come back
-            self.reactor.advance(1)
-
         # on_receive_pdu should throw an error
-        failure = self.failureResultOf(d)
         self.assertEqual(
             failure.value.args[0],
             (
@@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
-        self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+        extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+        self.assertEqual(extrem[0], "$join:test.serv")
 
     def test_retry_device_list_resync(self):
         """Tests that device lists are marked as stale if they couldn't be synced, and