summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8162.misc1
-rw-r--r--synapse/storage/database.py36
-rw-r--r--synapse/storage/databases/main/devices.py14
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py8
-rw-r--r--synapse/storage/databases/main/events_worker.py10
-rw-r--r--synapse/storage/databases/main/group_server.py18
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py15
-rw-r--r--synapse/storage/databases/main/profile.py17
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/storage/databases/main/registration.py10
-rw-r--r--synapse/storage/databases/main/rejections.py5
-rw-r--r--synapse/storage/databases/main/room.py10
-rw-r--r--synapse/storage/databases/main/state.py4
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/databases/main/user_directory.py9
-rw-r--r--tests/handlers/test_profile.py56
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/storage/test_base.py28
-rw-r--r--tests/storage/test_devices.py8
-rw-r--r--tests/storage/test_profile.py23
-rw-r--r--tests/storage/test_registration.py2
-rw-r--r--tests/storage/test_room.py20
25 files changed, 220 insertions, 113 deletions
diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc
new file mode 100644
index 0000000000..e26764dea1
--- /dev/null
+++ b/changelog.d/8162.misc
@@ -0,0 +1 @@
+ Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index bc327e344e..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -29,9 +29,11 @@ from typing import (
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
 from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         retcols: Iterable[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
-    ) -> defer.Deferred:
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
             allow_none: If true, return None instead of failing if the SELECT
                 statement returns no rows
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    async def simple_select_one_onecol(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         retcol: Iterable[str],
         allow_none: bool = False,
         desc: str = "simple_select_one_onecol",
-    ) -> defer.Deferred:
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
                 statement returns no rows
             desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 03b45dbc4d..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def get_device(self, user_id: str, device_id: str):
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            defer.Deferred for a dict containing the device information
+            A dict containing the device information
         Raises:
             StoreError: if the device is not found
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id: str):
+    async def get_device_list_last_stream_id_for_remote(
+        self, user_id: str
+    ) -> Optional[Any]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..301d5d845a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
 
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
-    def get_room_alias_creator(self, room_alias):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_room_alias_creator(self, room_alias: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..46c3e33cc6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return ret
 
-    def count_e2e_room_keys(self, user_id, version):
+    async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
         """Get the number of keys in a backup version.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
             keyvalues={"user_id": user_id, "version": version},
             retcol="COUNT(*)",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e1241a724b..d59d73938a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_received_ts(self, event_id):
+    async def get_received_ts(self, event_id: str) -> Optional[int]:
         """Get received_ts (when it was persisted) for the event.
 
         Raises an exception for unknown events.
 
         Args:
-            event_id (str)
+            event_id: The event ID to query.
 
         Returns:
-            Deferred[int|None]: Timestamp in milliseconds, or None for events
-            that were persisted before received_ts was implemented.
+            Timestamp in milliseconds, or None for events that were persisted
+            before received_ts was implemented.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index a488e0924b..c39864f59f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def get_group(self, group_id):
-        return self.db_pool.simple_select_one(
+    async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
         return bool(result)
 
-    def is_user_admin_in_group(self, group_id, user_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_admin_in_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="is_user_admin_in_group",
         )
 
-    def is_user_invited_to_local_group(self, group_id, user_id):
+    async def is_user_invited_to_local_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
         """Has the group server invited a user?
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..4ae255ebd8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
+
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
-    def get_local_media(self, media_id):
+    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
+
         Returns:
             None if the media_id doesn't exist.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_thumbnail",
         )
 
-    def get_cached_remote_media(self, origin, media_id):
-        return self.db_pool.simple_select_one(
+    async def get_cached_remote_media(
+        self, origin, media_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..fe30552c08 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         return users
 
     @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
+    async def user_last_seen_monthly_active(self, user_id: str) -> int:
         """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
+        Checks if a given user is part of the monthly active user group
 
+        Arguments:
+            user_id: user to add/update
+
+        Return:
+            Timestamp since last seen, None if never seen
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..b8233c4848 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
-    async def get_profileinfo(self, user_localpart):
+    async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(
                 table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    def get_profile_displayname(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_displayname(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
             desc="get_profile_displayname",
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
             desc="get_profile_avatar_url",
         )
 
-    def get_from_remote_profile_cache(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_from_remote_profile_cache(
+        self, user_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 6821476ee0..cea5ac9a68 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=3)
-    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 321a51cc6a..eced53d470 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
 
 import logging
 import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    def get_user_by_id(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="del_user_pending_deactivation",
         )
 
-    def get_user_pending_deactivation(self):
+    async def get_user_pending_deactivation(self) -> Optional[str]:
         """
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def get_rejection_reason(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index b3772be2b2..97ecdb16e4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
 
         self.config = hs.config
 
-    def get_room(self, room_id):
+    async def get_room(self, room_id: str) -> dict:
         """Retrieve a room.
 
         Args:
-            room_id (str): The ID of the room to retrieve.
+            room_id: The ID of the room to retrieve.
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
         return ret_val
 
     @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 991233a9bc..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return event.content.get("canonical_alias")
 
     @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+        return await self.db_pool.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
 
         return len(rooms_to_work_on)
 
-    def get_stats_positions(self):
+    async def get_stats_positions(self) -> int:
         """
         Returns the stats processor positions.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    def get_earliest_token_for_stats(self, stats_type, id):
+    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
         being calculated.
 
         Returns:
-            Deferred[int]
+            The earliest token.
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..20cbcd851c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
 
 import logging
 import re
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    def get_user_in_directory(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    def get_user_directory_stream_pos(self):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_user_directory_stream_pos(self) -> int:
+        return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index b609b30d4a..60ebc95f3e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
         displayname = yield defer.ensureDeferred(
             self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
     @defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
         # Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
         yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname("caroline", "Caroline")
+        )
 
         response = yield defer.ensureDeferred(
             self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
         avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
 
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/pic.gif",
         )
 
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index e01de158e5..834b4a0af6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_users_in_room = get_users_in_room
 
-        self.datastore.get_user_directory_stream_pos.return_value = (
+        self.datastore.get_user_directory_stream_pos.side_effect = (
             # we deliberately return a non-None stream pos to avoid doing an initial_spam
-            defer.succeed(1)
+            lambda: make_awaitable(1)
         )
 
         self.datastore.get_current_state_deltas.return_value = (0, None)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that the new user exists with all provided attributes
         self.assertEqual(user_id, "@bob:test")
         self.assertTrue(access_token)
-        self.assertTrue(self.store.get_user_by_id(user_id))
+        self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
 
         # Check that the email was assigned
         emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 13bcac743a..bf22540d99 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.db_pool.simple_select_one_onecol(
-            table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+        value = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one_onecol(
+                table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+            )
         )
 
         self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.db_pool.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            retcols=["colA", "colB", "colC"],
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                retcols=["colA", "colB", "colC"],
+            )
         )
 
         self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.db_pool.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "Not here"},
-            retcols=["colA"],
-            allow_none=True,
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "Not here"},
+                retcols=["colA"],
+                allow_none=True,
+            )
         )
 
         self.assertFalse(ret)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             self.store.store_device("user_id", "device_id", "display_name")
         )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             self.store.store_device("user_id", "device_id", "display_name 1")
         )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
         yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
         # check it worked
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9d5b8aa47d..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
     def test_displayname(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.u_frank.localpart, "http://my.site/here"
+            )
         )
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            ),
         )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 58f827d8d3..70c55cd650 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "user_type": None,
                 "deactivated": 0,
             },
-            (yield self.store.get_user_by_id(self.user_id)),
+            (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
         )
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
             },
-            (yield self.store.get_room(self.room.to_string())),
+            (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
         )
 
     @defer.inlineCallbacks
     def test_get_room_unknown_room(self):
-        self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+        self.assertIsNone(
+            (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+        )
 
     @defer.inlineCallbacks
     def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "public": True,
             },
-            (yield self.store.get_room_with_stats(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats(self.room.to_string())
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_get_room_with_stats_unknown_room(self):
-        self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats("!uknown:test")
+                )
+            ),
+        )
 
 
 class RoomEventsStoreTestCase(unittest.TestCase):