summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/event_federation.py38
-rw-r--r--synapse/storage/databases/main/metrics.py20
-rw-r--r--synapse/storage/databases/main/registration.py233
-rw-r--r--synapse/storage/databases/main/tags.py103
-rw-r--r--synapse/storage/databases/state/bg_updates.py18
5 files changed, 184 insertions, 228 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eddb32b4d3..484875f989 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,9 +15,7 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Dict, List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return dict(txn)
 
-    @defer.inlineCallbacks
-    def get_max_depth_of(self, event_ids):
+    async def get_max_depth_of(self, event_ids: List[str]) -> int:
         """Returns the max depth of a set of event IDs
 
         Args:
-            event_ids (list[str])
-
-        Returns
-            Deferred[int]
+            event_ids: The event IDs to calculate the max depth of.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="events",
             column="event_id",
             iterable=event_ids,
@@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return event_results
 
-    @defer.inlineCallbacks
-    def get_missing_events(self, room_id, earliest_events, latest_events, limit):
-        ids = yield self.db_pool.runInteraction(
+    async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+        ids = await self.db_pool.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id,
@@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = yield self.get_events_as_list(ids)
+        events = await self.get_events_as_list(ids)
         return events
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
-    @defer.inlineCallbacks
-    def get_successor_events(self, event_ids):
+    async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
         """Fetch all events that have the given events as a prev event
 
         Args:
-            event_ids (iterable[str])
-
-        Returns:
-            Deferred[list[str]]
+            event_ids: The events to use as the previous events.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="event_edges",
             column="prev_event_id",
             iterable=event_ids,
@@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-    @defer.inlineCallbacks
-    def _background_delete_non_state_event_auth(self, progress, batch_size):
+    async def _background_delete_non_state_event_auth(self, progress, batch_size):
         def delete_event_auth(txn):
             target_min_stream_id = progress.get("target_min_stream_id_inclusive")
             max_stream_id = progress.get("max_stream_id_exclusive")
@@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
             return min_stream_id >= target_min_stream_id
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_AUTH_STATE_ONLY, delete_event_auth
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_AUTH_STATE_ONLY
             )
 
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index baa7a5092a..686052bd83 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -15,8 +15,6 @@
 import typing
 from collections import Counter
 
-from twisted.internet import defer
-
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
@@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
         self._current_forward_extremities_amount = Counter([x[0] for x in res])
 
-    @defer.inlineCallbacks
-    def count_daily_messages(self):
+    async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
-        return ret
+        return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    @defer.inlineCallbacks
-    def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self):
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then thats your own fault.
@@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_sent_messages", _count_messages
         )
-        return ret
 
-    @defer.inlineCallbacks
-    def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self):
         def _count(txn):
             sql = """
                 SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
-        return ret
+        return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index f618629e09..402ae25571 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,8 @@
 
 import logging
 import re
-from typing import Optional
+from typing import Dict, List, Optional
 
-from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
@@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
@@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_by_id",
         )
 
-    @defer.inlineCallbacks
-    def is_trial_user(self, user_id):
+    async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
         N days of registration defined by `mau_trial_days` config
 
         Args:
-            user_id (str)
-
-        Returns:
-            Deferred[bool]
+            user_id: The user to check for trial status.
         """
 
-        info = yield self.get_user_by_id(user_id)
+        info = await self.get_user_by_id(user_id)
         if not info:
             return False
 
@@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
             "get_user_by_access_token", self._query_for_auth, token
         )
 
-    @cachedInlineCallbacks()
-    def get_expiration_ts_for_user(self, user_id):
+    @cached()
+    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
-            user_id (str): The ID of the user.
+            user_id: The ID of the user.
         Returns:
-            defer.Deferred: None, if the account has no expiration timestamp,
-                otherwise int representation of the timestamp (as a number of
-                milliseconds since epoch).
+            None, if the account has no expiration timestamp, otherwise int
+            representation of the timestamp (as a number of milliseconds since epoch).
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
             allow_none=True,
             desc="get_expiration_ts_for_user",
         )
-        return res
 
-    @defer.inlineCallbacks
-    def set_account_validity_for_user(
-        self, user_id, expiration_ts, email_sent, renewal_token=None
-    ):
+    async def set_account_validity_for_user(
+        self,
+        user_id: str,
+        expiration_ts: int,
+        email_sent: bool,
+        renewal_token: Optional[str] = None,
+    ) -> None:
         """Updates the account validity properties of the given account, with the
         given values.
 
         Args:
-            user_id (str): ID of the account to update properties for.
-            expiration_ts (int): New expiration date, as a timestamp in milliseconds
+            user_id: ID of the account to update properties for.
+            expiration_ts: New expiration date, as a timestamp in milliseconds
                 since epoch.
-            email_sent (bool): True means a renewal email has been sent for this
-                account and there's no need to send another one for the current validity
+            email_sent: True means a renewal email has been sent for this account
+                and there's no need to send another one for the current validity
                 period.
-            renewal_token (str): Renewal token the user can use to extend the validity
+            renewal_token: Renewal token the user can use to extend the validity
                 of their account. Defaults to no token.
         """
 
@@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
-    @defer.inlineCallbacks
-    def set_renewal_token_for_user(self, user_id, renewal_token):
+    async def set_renewal_token_for_user(
+        self, user_id: str, renewal_token: str
+    ) -> None:
         """Defines a renewal token for a given user.
 
         Args:
-            user_id (str): ID of the user to set the renewal token for.
-            renewal_token (str): Random unique string that will be used to renew the
+            user_id: ID of the user to set the renewal token for.
+            renewal_token: Random unique string that will be used to renew the
                 user's account.
 
         Raises:
             StoreError: The provided token is already set for another user.
         """
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"renewal_token": renewal_token},
             desc="set_renewal_token_for_user",
         )
 
-    @defer.inlineCallbacks
-    def get_user_from_renewal_token(self, renewal_token):
+    async def get_user_from_renewal_token(self, renewal_token: str) -> str:
         """Get a user ID from a renewal token.
 
         Args:
-            renewal_token (str): The renewal token to perform the lookup with.
+            renewal_token: The renewal token to perform the lookup with.
 
         Returns:
-            defer.Deferred[str]: The ID of the user to which the token belongs.
+            The ID of the user to which the token belongs.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
             retcol="user_id",
             desc="get_user_from_renewal_token",
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def get_renewal_token_for_user(self, user_id):
+    async def get_renewal_token_for_user(self, user_id: str) -> str:
         """Get the renewal token associated with a given user ID.
 
         Args:
-            user_id (str): The user ID to lookup a token for.
+            user_id: The user ID to lookup a token for.
 
         Returns:
-            defer.Deferred[str]: The renewal token associated with this user ID.
+            The renewal token associated with this user ID.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="renewal_token",
             desc="get_renewal_token_for_user",
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def get_users_expiring_soon(self):
+    async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+            A list of dictionaries mapping user ID to expiration time (in milliseconds).
         """
 
         def select_users_txn(txn, now_ms, renew_at):
@@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, values)
             return self.db_pool.cursor_to_dict(txn)
 
-        res = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
             self.clock.time_msec(),
             self.config.account_validity.renew_at,
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def set_renewal_mail_status(self, user_id, email_sent):
+    async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
         """Sets or unsets the flag that indicates whether a renewal email has been sent
         to the user (and the user hasn't renewed their account yet).
 
         Args:
-            user_id (str): ID of the user to set/unset the flag for.
-            email_sent (bool): Flag which indicates whether a renewal email has been sent
+            user_id: ID of the user to set/unset the flag for.
+            email_sent: Flag which indicates whether a renewal email has been sent
                 to this user.
         """
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"email_sent": email_sent},
             desc="set_renewal_mail_status",
         )
 
-    @defer.inlineCallbacks
-    def delete_account_validity_for_user(self, user_id):
+    async def delete_account_validity_for_user(self, user_id: str) -> None:
         """Deletes the entry for the given user in the account validity table, removing
         their expiration date and renewal token.
 
         Args:
-            user_id (str): ID of the user to remove from the account validity table.
+            user_id: ID of the user to remove from the account validity table.
         """
-        yield self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
         )
 
-    async def is_server_admin(self, user):
+    async def is_server_admin(self, user: UserID) -> bool:
         """Determines if a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
+            user: user ID of the user to test
 
-        Returns (bool):
+        Returns:
             true iff the user is a server admin, false otherwise.
         """
         res = await self.db_pool.simple_select_one_onecol(
@@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return None
 
-    @cachedInlineCallbacks()
-    def is_real_user(self, user_id):
+    @cached()
+    async def is_real_user(self, user_id: str) -> bool:
         """Determines if the user is a real user, ie does not have a 'user_type'.
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user 'user_type' is null or empty string
+            True if user 'user_type' is null or empty string
         """
-        res = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_real_user", self.is_real_user_txn, user_id
         )
-        return res
 
     @cached()
-    def is_support_user(self, user_id):
+    async def is_support_user(self, user_id: str) -> bool:
         """Determines if the user is of type UserTypes.SUPPORT
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user is of type UserTypes.SUPPORT
+            True if user is of type UserTypes.SUPPORT
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
 
@@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_by_external_id",
         )
 
-    @defer.inlineCallbacks
-    def count_all_users(self):
+    async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn):
@@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
     def count_daily_user_type(self):
         """
@@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             "count_daily_user_type", _count_daily_user_type
         )
 
-    @defer.inlineCallbacks
-    def count_nonbridged_users(self):
+    async def count_nonbridged_users(self):
         def _count_users(txn):
             txn.execute(
                 """
@@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
-    @defer.inlineCallbacks
-    def count_real_users(self):
+    async def count_real_users(self):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn):
@@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_real_users", _count_users)
 
     async def generate_user_id(self) -> str:
         """Generate a suitable localpart for a guest user
@@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             return ret["user_id"]
         return None
 
-    @defer.inlineCallbacks
-    def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self.db_pool.simple_upsert(
+    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+        await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    @defer.inlineCallbacks
-    def user_get_threepids(self, user_id):
-        ret = yield self.db_pool.simple_select_list(
+    async def user_get_threepids(self, user_id):
+        return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
             "user_get_threepids",
         )
-        return ret
 
     def user_delete_threepid(self, user_id, medium, address):
         return self.db_pool.simple_delete(
@@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_id_servers_user_bound",
         )
 
-    @cachedInlineCallbacks()
-    def get_user_deactivated_status(self, user_id):
+    @cached()
+    async def get_user_deactivated_status(self, user_id: str) -> bool:
         """Retrieve the value for the `deactivated` property for the provided user.
 
         Args:
-            user_id (str): The ID of the user to retrieve the status for.
+            user_id: The ID of the user to retrieve the status for.
 
         Returns:
-            defer.Deferred(bool): The requested value.
+            True if the user was deactivated, false if the user is still active.
         """
 
-        res = yield self.db_pool.simple_select_one_onecol(
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-    @defer.inlineCallbacks
-    def _background_update_set_deactivated_flag(self, progress, batch_size):
+    async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
@@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             else:
                 return False, len(rows)
 
-        end, nb_processed = yield self.db_pool.runInteraction(
+        end, nb_processed = await self.db_pool.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "users_set_deactivated_flag"
             )
 
         return nb_processed
 
-    @defer.inlineCallbacks
-    def _bg_user_threepids_grandfather(self, progress, batch_size):
+    async def _bg_user_threepids_grandfather(self, progress, batch_size):
         """We now track which identity servers a user binds their 3PID to, so
         we need to handle the case of existing bindings where we didn't track
         this.
@@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
-        yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
+        await self.db_pool.updates._end_background_update("user_threepids_grandfather")
 
         return 1
 
@@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
-    @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+    async def add_access_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+        valid_until_ms: Optional[int],
+    ) -> None:
         """Adds an access token for the given user.
 
         Args:
-            user_id (str): The user ID.
-            token (str): The new access token to add.
-            device_id (str): ID of the device to associate with the access
-                token
-            valid_until_ms (int|None): when the token is valid until. None for
-                no expiry.
+            user_id: The user ID.
+            token: The new access token to add.
+            device_id: ID of the device to associate with the access token
+            valid_until_ms: when the token is valid until. None for no expiry.
         Raises:
             StoreError if there was a problem adding this.
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
-        txn.call_after(self.is_guest.invalidate, (user_id,))
 
     def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
@@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         return self.db_pool.runInteraction("delete_access_token", f)
 
-    @cachedInlineCallbacks()
-    def is_guest(self, user_id):
-        res = yield self.db_pool.simple_select_one_onecol(
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             self.clock.time_msec(),
         )
 
-    @defer.inlineCallbacks
-    def set_user_deactivated_status(self, user_id, deactivated):
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
         """Set the `deactivated` property for the provided user to the provided value.
 
         Args:
-            user_id (str): The ID of the user to set the status for.
-            deactivated (bool): The value to set for `deactivated`.
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
         """
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_user_deactivated_status",
             self.set_user_deactivated_status_txn,
             user_id,
@@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         self._invalidate_cache_and_stream(
             txn, self.get_user_deactivated_status, (user_id,)
         )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @defer.inlineCallbacks
-    def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self):
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
@@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         txn, user["name"], use_delta=True
                     )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
             select_users_with_no_expiration_date_txn,
         )
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index eedd2d96c3..e4e0a0c433 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,14 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.storage._base import db_to_json
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    def get_tags_for_user(self, user_id):
+    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
         """Get all the tags for a user.
 
 
         Args:
-            user_id(str): The user to get the tags for.
+            user_id: The user to get the tags for.
         Returns:
-            A deferred dict mapping from room_id strings to dicts mapping from
-            tag strings to tag content.
+            A mapping from room_id strings to dicts mapping from tag strings to
+            tag content.
         """
 
-        deferred = self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        @deferred.addCallback
-        def tags_by_room(rows):
-            tags_by_room = {}
-            for row in rows:
-                room_tags = tags_by_room.setdefault(row["room_id"], {})
-                room_tags[row["tag"]] = db_to_json(row["content"])
-            return tags_by_room
-
-        return deferred
+        tags_by_room = {}
+        for row in rows:
+            room_tags = tags_by_room.setdefault(row["room_id"], {})
+            room_tags[row["tag"]] = db_to_json(row["content"])
+        return tags_by_room
 
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         return results, upto_token, limited
 
-    @defer.inlineCallbacks
-    def get_updated_tags(self, user_id, stream_id):
+    async def get_updated_tags(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, List[str]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
         Args:
             user_id(str): The user to get the tags for.
             stream_id(int): The earliest update to get for the user.
+
         Returns:
-            A deferred dict mapping from room_id strings to lists of tag
-            strings for all the rooms that changed since the stream_id token.
+            A mapping from room_id strings to lists of tag strings for all the
+            rooms that changed since the stream_id token.
         """
 
         def get_updated_tags_txn(txn):
@@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if not changed:
             return {}
 
-        room_ids = yield self.db_pool.runInteraction(
+        room_ids = await self.db_pool.runInteraction(
             "get_updated_tags", get_updated_tags_txn
         )
 
         results = {}
         if room_ids:
-            tags_by_room = yield self.get_tags_for_user(user_id)
+            tags_by_room = await self.get_tags_for_user(user_id)
             for room_id in room_ids:
                 results[room_id] = tags_by_room.get(room_id, {})
 
         return results
 
-    def get_tags_for_room(self, user_id, room_id):
+    async def get_tags_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the tags for the given room
+
         Args:
-            user_id(str): The user to get tags for
-            room_id(str): The room to get tags for
+            user_id: The user to get tags for
+            room_id: The room to get tags for
+
         Returns:
-            A deferred list of string tags.
+            A mapping of tags to tag content.
         """
-        return self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             table="room_tags",
             keyvalues={"user_id": user_id, "room_id": room_id},
             retcols=("tag", "content"),
             desc="get_tags_for_room",
-        ).addCallback(
-            lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
         )
+        return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
 
 class TagsStore(TagsWorkerStore):
-    @defer.inlineCallbacks
-    def add_tag_to_room(self, user_id, room_id, tag, content):
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
         """Add a tag to a room for a user.
+
         Args:
-            user_id(str): The user to add a tag for.
-            room_id(str): The room to add a tag for.
-            tag(str): The tag name to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the tag has been added.
+            The next account data ID.
         """
         content_json = json.dumps(content)
 
@@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
+            await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def remove_tag_from_room(self, user_id, room_id, tag):
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
         """Remove a tag from a room for a user.
+
         Returns:
-            A deferred that completes once the tag has been removed
+            The next account data ID.
         """
 
         def remove_tag_txn(txn, next_id):
@@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
+            await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    def _update_revision_txn(self, txn, user_id, room_id, next_id):
+    def _update_revision_txn(
+        self, txn, user_id: str, room_id: str, next_id: int
+    ) -> None:
         """Update the latest revision of the tags for the given user and room.
 
         Args:
             txn: The database cursor
-            user_id(str): The ID of the user.
-            room_id(str): The ID of the room.
-            next_id(int): The the revision to advance to.
+            user_id: The ID of the user.
+            room_id: The ID of the room.
+            next_id: The the revision to advance to.
         """
 
         txn.call_after(
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1e2d584098..139085b672 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
@@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )
 
-    @defer.inlineCallbacks
-    def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding
         them as deltas.
         """
@@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
 
         if max_group is None:
-            rows = yield self.db_pool.execute(
+            rows = await self.db_pool.execute(
                 "_background_deduplicate_state",
                 None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
@@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
             return False, batch_size
 
-        finished, result = yield self.db_pool.runInteraction(
+        finished, result = await self.db_pool.runInteraction(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
         )
 
         if finished:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
             )
 
         return result * BATCH_SIZE_SCALE_FACTOR
 
-    @defer.inlineCallbacks
-    def _background_index_state(self, progress, batch_size):
+    async def _background_index_state(self, progress, batch_size):
         def reindex_txn(conn):
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
@@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                 )
                 txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
 
-        yield self.db_pool.runWithConnection(reindex_txn)
+        await self.db_pool.runWithConnection(reindex_txn)
 
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.updates._end_background_update(
             self.STATE_GROUP_INDEX_UPDATE_NAME
         )