summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-11 17:21:13 -0400
committerGitHub <noreply@github.com>2020-08-11 17:21:13 -0400
commita0acdfa9e93ae63a3adee264d5420fdd1d38d76e (patch)
treee39f391b56dcbb25ebc381e15a635cab3abc2d21 /synapse/storage/databases/main
parentAuto set logging filter (#8051) (diff)
downloadsynapse-a0acdfa9e93ae63a3adee264d5420fdd1d38d76e.tar.xz
Converts event_federation and registration databases to async/await (#8061)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/event_federation.py38
-rw-r--r--synapse/storage/databases/main/registration.py233
2 files changed, 118 insertions, 153 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eddb32b4d3..484875f989 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,9 +15,7 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Dict, List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return dict(txn)
 
-    @defer.inlineCallbacks
-    def get_max_depth_of(self, event_ids):
+    async def get_max_depth_of(self, event_ids: List[str]) -> int:
         """Returns the max depth of a set of event IDs
 
         Args:
-            event_ids (list[str])
-
-        Returns
-            Deferred[int]
+            event_ids: The event IDs to calculate the max depth of.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="events",
             column="event_id",
             iterable=event_ids,
@@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return event_results
 
-    @defer.inlineCallbacks
-    def get_missing_events(self, room_id, earliest_events, latest_events, limit):
-        ids = yield self.db_pool.runInteraction(
+    async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+        ids = await self.db_pool.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id,
@@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = yield self.get_events_as_list(ids)
+        events = await self.get_events_as_list(ids)
         return events
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
-    @defer.inlineCallbacks
-    def get_successor_events(self, event_ids):
+    async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
         """Fetch all events that have the given events as a prev event
 
         Args:
-            event_ids (iterable[str])
-
-        Returns:
-            Deferred[list[str]]
+            event_ids: The events to use as the previous events.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="event_edges",
             column="prev_event_id",
             iterable=event_ids,
@@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-    @defer.inlineCallbacks
-    def _background_delete_non_state_event_auth(self, progress, batch_size):
+    async def _background_delete_non_state_event_auth(self, progress, batch_size):
         def delete_event_auth(txn):
             target_min_stream_id = progress.get("target_min_stream_id_inclusive")
             max_stream_id = progress.get("max_stream_id_exclusive")
@@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
             return min_stream_id >= target_min_stream_id
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_AUTH_STATE_ONLY, delete_event_auth
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_AUTH_STATE_ONLY
             )
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index f618629e09..402ae25571 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,8 @@
 
 import logging
 import re
-from typing import Optional
+from typing import Dict, List, Optional
 
-from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
@@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
@@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_by_id",
         )
 
-    @defer.inlineCallbacks
-    def is_trial_user(self, user_id):
+    async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
         N days of registration defined by `mau_trial_days` config
 
         Args:
-            user_id (str)
-
-        Returns:
-            Deferred[bool]
+            user_id: The user to check for trial status.
         """
 
-        info = yield self.get_user_by_id(user_id)
+        info = await self.get_user_by_id(user_id)
         if not info:
             return False
 
@@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
             "get_user_by_access_token", self._query_for_auth, token
         )
 
-    @cachedInlineCallbacks()
-    def get_expiration_ts_for_user(self, user_id):
+    @cached()
+    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
-            user_id (str): The ID of the user.
+            user_id: The ID of the user.
         Returns:
-            defer.Deferred: None, if the account has no expiration timestamp,
-                otherwise int representation of the timestamp (as a number of
-                milliseconds since epoch).
+            None, if the account has no expiration timestamp, otherwise int
+            representation of the timestamp (as a number of milliseconds since epoch).
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
             allow_none=True,
             desc="get_expiration_ts_for_user",
         )
-        return res
 
-    @defer.inlineCallbacks
-    def set_account_validity_for_user(
-        self, user_id, expiration_ts, email_sent, renewal_token=None
-    ):
+    async def set_account_validity_for_user(
+        self,
+        user_id: str,
+        expiration_ts: int,
+        email_sent: bool,
+        renewal_token: Optional[str] = None,
+    ) -> None:
         """Updates the account validity properties of the given account, with the
         given values.
 
         Args:
-            user_id (str): ID of the account to update properties for.
-            expiration_ts (int): New expiration date, as a timestamp in milliseconds
+            user_id: ID of the account to update properties for.
+            expiration_ts: New expiration date, as a timestamp in milliseconds
                 since epoch.
-            email_sent (bool): True means a renewal email has been sent for this
-                account and there's no need to send another one for the current validity
+            email_sent: True means a renewal email has been sent for this account
+                and there's no need to send another one for the current validity
                 period.
-            renewal_token (str): Renewal token the user can use to extend the validity
+            renewal_token: Renewal token the user can use to extend the validity
                 of their account. Defaults to no token.
         """
 
@@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
-    @defer.inlineCallbacks
-    def set_renewal_token_for_user(self, user_id, renewal_token):
+    async def set_renewal_token_for_user(
+        self, user_id: str, renewal_token: str
+    ) -> None:
         """Defines a renewal token for a given user.
 
         Args:
-            user_id (str): ID of the user to set the renewal token for.
-            renewal_token (str): Random unique string that will be used to renew the
+            user_id: ID of the user to set the renewal token for.
+            renewal_token: Random unique string that will be used to renew the
                 user's account.
 
         Raises:
             StoreError: The provided token is already set for another user.
         """
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"renewal_token": renewal_token},
             desc="set_renewal_token_for_user",
         )
 
-    @defer.inlineCallbacks
-    def get_user_from_renewal_token(self, renewal_token):
+    async def get_user_from_renewal_token(self, renewal_token: str) -> str:
         """Get a user ID from a renewal token.
 
         Args:
-            renewal_token (str): The renewal token to perform the lookup with.
+            renewal_token: The renewal token to perform the lookup with.
 
         Returns:
-            defer.Deferred[str]: The ID of the user to which the token belongs.
+            The ID of the user to which the token belongs.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
             retcol="user_id",
             desc="get_user_from_renewal_token",
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def get_renewal_token_for_user(self, user_id):
+    async def get_renewal_token_for_user(self, user_id: str) -> str:
         """Get the renewal token associated with a given user ID.
 
         Args:
-            user_id (str): The user ID to lookup a token for.
+            user_id: The user ID to lookup a token for.
 
         Returns:
-            defer.Deferred[str]: The renewal token associated with this user ID.
+            The renewal token associated with this user ID.
         """
-        res = yield self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="renewal_token",
             desc="get_renewal_token_for_user",
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def get_users_expiring_soon(self):
+    async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+            A list of dictionaries mapping user ID to expiration time (in milliseconds).
         """
 
         def select_users_txn(txn, now_ms, renew_at):
@@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn.execute(sql, values)
             return self.db_pool.cursor_to_dict(txn)
 
-        res = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
             self.clock.time_msec(),
             self.config.account_validity.renew_at,
         )
 
-        return res
-
-    @defer.inlineCallbacks
-    def set_renewal_mail_status(self, user_id, email_sent):
+    async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
         """Sets or unsets the flag that indicates whether a renewal email has been sent
         to the user (and the user hasn't renewed their account yet).
 
         Args:
-            user_id (str): ID of the user to set/unset the flag for.
-            email_sent (bool): Flag which indicates whether a renewal email has been sent
+            user_id: ID of the user to set/unset the flag for.
+            email_sent: Flag which indicates whether a renewal email has been sent
                 to this user.
         """
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"email_sent": email_sent},
             desc="set_renewal_mail_status",
         )
 
-    @defer.inlineCallbacks
-    def delete_account_validity_for_user(self, user_id):
+    async def delete_account_validity_for_user(self, user_id: str) -> None:
         """Deletes the entry for the given user in the account validity table, removing
         their expiration date and renewal token.
 
         Args:
-            user_id (str): ID of the user to remove from the account validity table.
+            user_id: ID of the user to remove from the account validity table.
         """
-        yield self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
         )
 
-    async def is_server_admin(self, user):
+    async def is_server_admin(self, user: UserID) -> bool:
         """Determines if a user is an admin of this homeserver.
 
         Args:
-            user (UserID): user ID of the user to test
+            user: user ID of the user to test
 
-        Returns (bool):
+        Returns:
             true iff the user is a server admin, false otherwise.
         """
         res = await self.db_pool.simple_select_one_onecol(
@@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return None
 
-    @cachedInlineCallbacks()
-    def is_real_user(self, user_id):
+    @cached()
+    async def is_real_user(self, user_id: str) -> bool:
         """Determines if the user is a real user, ie does not have a 'user_type'.
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user 'user_type' is null or empty string
+            True if user 'user_type' is null or empty string
         """
-        res = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_real_user", self.is_real_user_txn, user_id
         )
-        return res
 
     @cached()
-    def is_support_user(self, user_id):
+    async def is_support_user(self, user_id: str) -> bool:
         """Determines if the user is of type UserTypes.SUPPORT
 
         Args:
-            user_id (str): user id to test
+            user_id: user id to test
 
         Returns:
-            Deferred[bool]: True if user is of type UserTypes.SUPPORT
+            True if user is of type UserTypes.SUPPORT
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
 
@@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_user_by_external_id",
         )
 
-    @defer.inlineCallbacks
-    def count_all_users(self):
+    async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn):
@@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
     def count_daily_user_type(self):
         """
@@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             "count_daily_user_type", _count_daily_user_type
         )
 
-    @defer.inlineCallbacks
-    def count_nonbridged_users(self):
+    async def count_nonbridged_users(self):
         def _count_users(txn):
             txn.execute(
                 """
@@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_users", _count_users)
 
-    @defer.inlineCallbacks
-    def count_real_users(self):
+    async def count_real_users(self):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn):
@@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
-        return ret
+        return await self.db_pool.runInteraction("count_real_users", _count_users)
 
     async def generate_user_id(self) -> str:
         """Generate a suitable localpart for a guest user
@@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             return ret["user_id"]
         return None
 
-    @defer.inlineCallbacks
-    def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self.db_pool.simple_upsert(
+    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+        await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    @defer.inlineCallbacks
-    def user_get_threepids(self, user_id):
-        ret = yield self.db_pool.simple_select_list(
+    async def user_get_threepids(self, user_id):
+        return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
             "user_get_threepids",
         )
-        return ret
 
     def user_delete_threepid(self, user_id, medium, address):
         return self.db_pool.simple_delete(
@@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_id_servers_user_bound",
         )
 
-    @cachedInlineCallbacks()
-    def get_user_deactivated_status(self, user_id):
+    @cached()
+    async def get_user_deactivated_status(self, user_id: str) -> bool:
         """Retrieve the value for the `deactivated` property for the provided user.
 
         Args:
-            user_id (str): The ID of the user to retrieve the status for.
+            user_id: The ID of the user to retrieve the status for.
 
         Returns:
-            defer.Deferred(bool): The requested value.
+            True if the user was deactivated, false if the user is still active.
         """
 
-        res = yield self.db_pool.simple_select_one_onecol(
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-    @defer.inlineCallbacks
-    def _background_update_set_deactivated_flag(self, progress, batch_size):
+    async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
@@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             else:
                 return False, len(rows)
 
-        end, nb_processed = yield self.db_pool.runInteraction(
+        end, nb_processed = await self.db_pool.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "users_set_deactivated_flag"
             )
 
         return nb_processed
 
-    @defer.inlineCallbacks
-    def _bg_user_threepids_grandfather(self, progress, batch_size):
+    async def _bg_user_threepids_grandfather(self, progress, batch_size):
         """We now track which identity servers a user binds their 3PID to, so
         we need to handle the case of existing bindings where we didn't track
         this.
@@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
-        yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
+        await self.db_pool.updates._end_background_update("user_threepids_grandfather")
 
         return 1
 
@@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
-    @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+    async def add_access_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+        valid_until_ms: Optional[int],
+    ) -> None:
         """Adds an access token for the given user.
 
         Args:
-            user_id (str): The user ID.
-            token (str): The new access token to add.
-            device_id (str): ID of the device to associate with the access
-                token
-            valid_until_ms (int|None): when the token is valid until. None for
-                no expiry.
+            user_id: The user ID.
+            token: The new access token to add.
+            device_id: ID of the device to associate with the access token
+            valid_until_ms: when the token is valid until. None for no expiry.
         Raises:
             StoreError if there was a problem adding this.
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
-        txn.call_after(self.is_guest.invalidate, (user_id,))
 
     def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
@@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         return self.db_pool.runInteraction("delete_access_token", f)
 
-    @cachedInlineCallbacks()
-    def is_guest(self, user_id):
-        res = yield self.db_pool.simple_select_one_onecol(
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             self.clock.time_msec(),
         )
 
-    @defer.inlineCallbacks
-    def set_user_deactivated_status(self, user_id, deactivated):
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
         """Set the `deactivated` property for the provided user to the provided value.
 
         Args:
-            user_id (str): The ID of the user to set the status for.
-            deactivated (bool): The value to set for `deactivated`.
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
         """
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_user_deactivated_status",
             self.set_user_deactivated_status_txn,
             user_id,
@@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         self._invalidate_cache_and_stream(
             txn, self.get_user_deactivated_status, (user_id,)
         )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @defer.inlineCallbacks
-    def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self):
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
@@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         txn, user["name"], use_delta=True
                     )
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
             select_users_with_no_expiration_date_txn,
         )