summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-28 11:34:50 -0400
committerGitHub <noreply@github.com>2020-08-28 11:34:50 -0400
commitd58fda99ffd29c15254788476fb8775370eebec5 (patch)
tree8fae2ed28132aaea20c03dc25d07f6e9bcd7ee06 /synapse/storage/databases/main
parentOnly return devices with keys from `/federation/v1/user/devices/` (#8198) (diff)
downloadsynapse-d58fda99ffd29c15254788476fb8775370eebec5.tar.xz
Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py38
-rw-r--r--synapse/storage/databases/main/registration.py238
-rw-r--r--synapse/storage/databases/main/roommember.py52
3 files changed, 168 insertions, 160 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e8834b2162..384c39f4d7 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import List
+from typing import Dict, List, Union
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # Now return the first `limit`
         return notifs[:limit]
 
-    def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+    async def get_if_maybe_push_in_range_for_user(
+        self, user_id: str, min_stream_ordering: int
+    ) -> bool:
         """A fast check to see if there might be something to push for the
         user since the given stream ordering. May return false positives.
 
         Useful to know whether to bother starting a pusher on start up or not.
 
         Args:
-            user_id (str)
-            min_stream_ordering (int)
+            user_id
+            min_stream_ordering
 
         Returns:
-            Deferred[bool]: True if there may be push to process, False if
-            there definitely isn't.
+            True if there may be push to process, False if there definitely isn't.
         """
 
         def _get_if_maybe_push_in_range_for_user_txn(txn):
@@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
 
-    async def add_push_actions_to_staging(self, event_id, user_id_actions):
+    async def add_push_actions_to_staging(
+        self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
+    ) -> None:
         """Add the push actions for the event to the push action staging area.
 
         Args:
-            event_id (str)
-            user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
-                user_id to list of push actions, where an action can either be
-                a string or dict.
-
-        Returns:
-            Deferred
+            event_id
+            user_id_actions: A mapping of user_id to list of push actions, where
+                an action can either be a string or dict.
         """
 
         if not user_id_actions:
@@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
         )
 
-    def find_first_stream_ordering_after_ts(self, ts):
+    async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
         """Gets the stream ordering corresponding to a given timestamp.
 
         Specifically, finds the stream_ordering of the first event that was
@@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         relatively slow.
 
         Args:
-            ts (int): timestamp in millis
+            ts: timestamp in millis
 
         Returns:
-            Deferred[int]: stream ordering of the first event received on/after
-                the timestamp
+            stream ordering of the first event received on/after the timestamp
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 12689f4308..000b8ccff4 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore):
         return is_trial
 
     @cached()
-    def get_user_by_access_token(self, token):
+    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
         """Get a user from the given access token.
 
         Args:
-            token (str): The access token of a user.
+            token: The access token of a user.
         Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`,
-                `valid_until_ms`.
+            None, if the token did not match, otherwise dict
+            including the keys `name`, `is_guest`, `device_id`, `token_id`,
+            `valid_until_ms`.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
@@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return bool(res) if res else False
 
-    def set_server_admin(self, user, admin):
+    async def set_server_admin(self, user: UserID, admin: bool) -> None:
         """Sets whether a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
-            admin (bool): true iff the user is to be a server admin,
-                false otherwise.
+            user: user ID of the user to test
+            admin: true iff the user is to be a server admin, false otherwise.
         """
 
         def set_server_admin_txn(txn):
@@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_user_by_id, (user.to_string(),)
             )
 
-        return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+        await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
     def _query_for_auth(self, txn, token):
         sql = (
@@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
         return True if res == UserTypes.SUPPORT else False
 
-    def get_users_by_id_case_insensitive(self, user_id):
+    async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
         """Gets users that match user_id case insensitively.
-        Returns a mapping of user_id -> password_hash.
+
+        Returns:
+             A mapping of user_id -> password_hash.
         """
 
         def f(txn):
@@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
+        return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
@@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
-    def count_daily_user_type(self):
+    async def count_daily_user_type(self) -> Dict[str, int]:
         """
         Counts 1) native non guest users
                2) native guests users
@@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_user_type", _count_daily_user_type
         )
 
@@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
         # Convert the integer into a boolean.
         return res == 1
 
-    def get_threepid_validation_session(
-        self, medium, client_secret, address=None, sid=None, validated=True
-    ):
+    async def get_threepid_validation_session(
+        self,
+        medium: Optional[str],
+        client_secret: str,
+        address: Optional[str] = None,
+        sid: Optional[str] = None,
+        validated: Optional[bool] = True,
+    ) -> Optional[Dict[str, Any]]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
 
         Args:
-            medium (str|None): The medium of the 3PID
-            address (str|None): The address of the 3PID
-            sid (str|None): The ID of the validation session
-            client_secret (str): A unique string provided by the client to help identify this
+            medium: The medium of the 3PID
+            client_secret: A unique string provided by the client to help identify this
                 validation attempt
-            validated (bool|None): Whether sessions should be filtered by
+            address: The address of the 3PID
+            sid: The ID of the validation session
+            validated: Whether sessions should be filtered by
                 whether they have been validated already or not. None to
                 perform no filtering
 
         Returns:
-            Deferred[dict|None]: A dict containing the following:
+            A dict containing the following:
                 * address - address of the 3pid
                 * medium - medium of the 3pid
                 * client_secret - a secret provided by the client for this validation session
@@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             return rows[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
-    def delete_threepid_session(self, session_id):
+    async def delete_threepid_session(self, session_id: str) -> None:
         """Removes a threepid validation session from the database. This can
         be done after validation has been performed and whatever action was
         waiting on it has been carried out
 
         Args:
-            session_id (str): The ID of the session to delete
+            session_id: The ID of the session to delete
         """
 
         def delete_threepid_session_txn(txn):
@@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 keyvalues={"session_id": session_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
@@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
-    def register_user(
+    async def register_user(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-        shadow_banned=False,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Attempts to register an account.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
-                upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
-                false to add a regular user account.
-            appservice_id (str): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode): Optionally create a profile for
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Whether this is a guest account being upgraded to a
+                non-guest account.
+            make_guest: True if the the new user should be guest, false to add a
+                regular user account.
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a profile for
                 the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
-                api.constants.UserTypes, or None for a normal user.
-            shadow_banned (bool): Whether the user is shadow-banned,
-                i.e. they may be told their requests succeeded but we ignore them.
+            admin: is an admin user?
+            user_type: type of user. One of the values from api.constants.UserTypes,
+                or None for a normal user.
+            shadow_banned: Whether the user is shadow-banned, i.e. they may be
+                told their requests succeeded but we ignore them.
 
         Raises:
             StoreError if the user_id could not be registered.
-
-        Returns:
-            Deferred
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="record_user_external_id",
         )
 
-    def user_set_password_hash(self, user_id, password_hash):
+    async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
@@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "user_set_password_hash", user_set_password_hash_txn
         )
 
-    def user_set_consent_version(self, user_id, consent_version):
+    async def user_set_consent_version(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy the user has consented
-                to
+            user_id: full mxid of the user to update
+            consent_version: version of the policy the user has consented to
 
         Raises:
             StoreError(404) if user not found
@@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_version", f)
+        await self.db_pool.runInteraction("user_set_consent_version", f)
 
-    def user_set_consent_server_notice_sent(self, user_id, consent_version):
+    async def user_set_consent_server_notice_sent(
+        self, user_id: str, consent_version: str
+    ) -> None:
         """Updates the user table to record that we have sent the user a server
         notice about privacy policy consent
 
         Args:
-            user_id (str): full mxid of the user to update
-            consent_version (str): version of the policy we have notified the
-                user about
+            user_id: full mxid of the user to update
+            consent_version: version of the policy we have notified the user about
 
         Raises:
             StoreError(404) if user not found
@@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
+        await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
 
-    def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+    async def user_delete_access_tokens(
+        self,
+        user_id: str,
+        except_token_id: Optional[str] = None,
+        device_id: Optional[str] = None,
+    ) -> List[Tuple[str, int, Optional[str]]]:
         """
         Invalidate access tokens belonging to a user
 
         Args:
-            user_id (str):  ID of user the tokens belong to
-            except_token_id (str): list of access_tokens IDs which should
-                *not* be deleted
-            device_id (str|None):  ID of device the tokens are associated with.
+            user_id: ID of user the tokens belong to
+            except_token_id: access_tokens ID which should *not* be deleted
+            device_id: ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
-            defer.Deferred[list[str, int, str|None, int]]: a list of
-                (token, token id, device id) for each of the deleted tokens
+            A tuple of (token, token id, device id) for each of the deleted tokens
         """
 
         def f(txn):
@@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
             return tokens_and_devices
 
-        return self.db_pool.runInteraction("user_delete_access_tokens", f)
+        return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
-    def delete_access_token(self, access_token):
+    async def delete_access_token(self, access_token: str) -> None:
         def f(txn):
             self.db_pool.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
@@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.db_pool.runInteraction("delete_access_token", f)
+        await self.db_pool.runInteraction("delete_access_token", f)
 
     @cached()
     async def is_guest(self, user_id: str) -> bool:
@@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="get_users_pending_deactivation",
         )
 
-    def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+    async def validate_threepid_session(
+        self, session_id: str, client_secret: str, token: str, current_ts: int
+    ) -> Optional[str]:
         """Attempt to validate a threepid session using a token
 
         Args:
-            session_id (str): The id of a validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            token (str): A validation token
-            current_ts (int): The current unix time in milliseconds. Used for
-                checking token expiry status
+            session_id: The id of a validation session
+            client_secret: A unique string provided by the client to help identify
+                this validation attempt
+            token: A validation token
+            current_ts: The current unix time in milliseconds. Used for checking
+                token expiry status
 
         Raises:
             ThreepidValidationError: if a matching validation token was not found or has
                 expired
 
         Returns:
-            deferred str|None: A str representing a link to redirect the user
-            to if there is one.
+            A str representing a link to redirect the user to if there is one.
         """
 
         # Insert everything into a transaction in order to run atomically
@@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             return next_link
 
         # Return next_link if it exists
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
-    def start_or_continue_validation_session(
+    async def start_or_continue_validation_session(
         self,
-        medium,
-        address,
-        session_id,
-        client_secret,
-        send_attempt,
-        next_link,
-        token,
-        token_expires,
-    ):
+        medium: str,
+        address: str,
+        session_id: str,
+        client_secret: str,
+        send_attempt: int,
+        next_link: Optional[str],
+        token: str,
+        token_expires: int,
+    ) -> None:
         """Creates a new threepid validation session if it does not already
         exist and associates a new validation token with it
 
         Args:
-            medium (str): The medium of the 3PID
-            address (str): The address of the 3PID
-            session_id (str): The id of this validation session
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            send_attempt (int): The latest send_attempt on this session
-            next_link (str|None): The link to redirect the user to upon
-                successful validation
-            token (str): The validation token
-            token_expires (int): The timestamp for which after the token
-                will no longer be valid
+            medium: The medium of the 3PID
+            address: The address of the 3PID
+            session_id: The id of this validation session
+            client_secret: A unique string provided by the client to help
+                identify this validation attempt
+            send_attempt: The latest send_attempt on this session
+            next_link: The link to redirect the user to upon successful validation
+            token: The validation token
+            token_expires: The timestamp for which after the token will no
+                longer be valid
         """
 
         def start_or_continue_validation_session_txn(txn):
@@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 },
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
 
-    def cull_expired_threepid_validation_tokens(self):
+    async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
         def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             DELETE FROM threepid_validation_token WHERE
             expires < ?
             """
-            return txn.execute(sql, (ts,))
+            txn.execute(sql, (ts,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 161edbeccb..c4ce5c17c2 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id: str):
-        return self.db_pool.runInteraction(
+    async def get_users_in_room(self, room_id: str) -> List[str]:
+        return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
@@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return [r[0] for r in txn]
 
     @cached(max_entries=100000)
-    def get_room_summary(self, room_id: str):
+    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
         """ Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
             room_id: The room ID to query
         Returns:
-            Deferred[dict[str, MemberSummary]:
-                dict of membership states, pointing to a MemberSummary named tuple.
+            dict of membership states, pointing to a MemberSummary named tuple.
         """
 
         def _get_room_summary_txn(txn):
@@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             return res
 
-        return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+        return await self.db_pool.runInteraction(
+            "get_room_summary", _get_room_summary_txn
+        )
 
     @cached()
-    def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+    async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
         """Get all the rooms the *local* user is invited to.
 
         Args:
             user_id: The user ID.
 
         Returns:
-            A awaitable list of RoomsForUser.
+            A list of RoomsForUser.
         """
 
-        return self.get_rooms_for_local_user_where_membership_is(
+        return await self.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE]
         )
 
@@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+    async def get_rooms_for_user_with_stream_ordering(
+        self, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
@@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id
 
         Returns:
-            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
-            the rooms the user is in currently, along with the stream ordering
-            of the most recent join for that user and room.
+            Returns the rooms the user is in currently, along with the stream
+            ordering of the most recent join for that user and room.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_for_user_with_stream_ordering",
             self._get_rooms_for_user_with_stream_ordering_txn,
             user_id,
         )
 
-    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+    def _get_rooms_for_user_with_stream_ordering_txn(
+        self, txn, user_id: str
+    ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
         # for rooms the server is participating in.
@@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
-        return results
+        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
@@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return count == 0
 
     @cached()
-    def get_forgotten_rooms_for_user(self, user_id: str):
+    async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
         """Gets all rooms the user has forgotten.
 
         Args:
-            user_id
+            user_id: The user ID to query the rooms of.
 
         Returns:
-            Deferred[set[str]]
+            The forgotten rooms.
         """
 
         def _get_forgotten_rooms_for_user_txn(txn):
@@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id,))
             return {row[0] for row in txn if row[1] == 0}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
@@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def forget(self, user_id: str, room_id: str):
+    async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
         def f(txn):
@@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
 
-        return self.db_pool.runInteraction("forget_membership", f)
+        await self.db_pool.runInteraction("forget_membership", f)
 
 
 class _JoinedHostsCache(object):