summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-04-12 12:54:00 +0200
committerGitHub <noreply@github.com>2022-04-12 11:54:00 +0100
commit1783156dbcf4164692e66275d1c29857c434995b (patch)
treed2d18c66112307a447be9febed9669c55276cea0 /synapse/storage/databases/main/registration.py
parentPoetry: select olddeps using `poetry` (#12407) (diff)
downloadsynapse-1783156dbcf4164692e66275d1c29857c434995b.tar.xz
Add some type hints to datastore (#12423)
* Add some type hints to datastore

* newsfile

* change `Collection` to `List`

* refactor return type of `select_users_txn`

* correct type hint in `stream.py`

* Remove `Optional` in `select_users_txn`

* remove not needed return type in `__init__`

* Revert change in `get_stream_id_for_event_txn`

* Remove import from `Literal`
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py133
1 files changed, 79 insertions, 54 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c7634c92fd..d43163c27c 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
-    def _default_token_owner(self):
+    def _default_token_owner(self) -> str:
         return self.user_id
 
 
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 the account.
         """
 
-        def set_account_validity_for_user_txn(txn):
+        def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_txn(
                 txn=txn,
                 table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+    async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+            A list of tuples, each with a user ID and expiration time (in milliseconds).
         """
 
-        def select_users_txn(txn, now_ms, renew_at):
+        def select_users_txn(
+            txn: LoggingTransaction, now_ms: int, renew_at: int
+        ) -> List[Tuple[str, int]]:
             sql = (
                 "SELECT user_id, expiration_ts_ms FROM account_validity"
                 " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             admin: true iff the user is to be a server admin, false otherwise.
         """
 
-        def set_server_admin_txn(txn):
+        def set_server_admin_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
             )
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             user_type: type of the user or None for a user without a type.
         """
 
-        def set_user_type_txn(txn):
+        def set_user_type_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"user_type": user_type}
             )
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
 
-    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+    def _query_for_auth(
+        self, txn: LoggingTransaction, token: str
+    ) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
                 users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "is_support_user", self.is_support_user_txn, user_id
         )
 
-    def is_real_user_txn(self, txn, user_id):
+    def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return res is None
 
-    def is_support_user_txn(self, txn, user_id):
+    def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
              A mapping of user_id -> password_hash.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Dict[str, str]:
             sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
             txn.execute(sql, (user_id,))
-            return dict(txn)
+            result = cast(List[Tuple[str, str]], txn.fetchall())
+            return dict(result)
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
@@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         def _replace_user_external_id_txn(
             txn: LoggingTransaction,
-        ):
+        ) -> None:
             _remove_user_external_ids_txn(txn, user_id)
 
             for auth_provider, external_id in record_external_ids:
@@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return [(r["auth_provider"], r["external_id"]) for r in res]
 
-    async def count_all_users(self):
+    async def count_all_users(self) -> int:
         """Counts all users registered on the homeserver."""
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute("SELECT COUNT(*) AS users FROM users")
             rows = self.db_pool.cursor_to_dict(txn)
             if rows:
@@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         who registered on the homeserver in the past 24 hours
         """
 
-        def _count_daily_user_type(txn):
+        def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
             sql = """
@@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "count_daily_user_type", _count_daily_user_type
         )
 
-    async def count_nonbridged_users(self):
-        def _count_users(txn):
+    async def count_nonbridged_users(self) -> int:
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT COUNT(*) FROM users
                 WHERE appservice_id IS NULL
             """
             )
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
-    async def count_real_users(self):
+    async def count_real_users(self) -> int:
         """Counts all users without a special user_type registered on the homeserver."""
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
             rows = self.db_pool.cursor_to_dict(txn)
             if rows:
@@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         return user_id
 
     def get_user_id_by_threepid_txn(
-        self, txn, medium: str, address: str
+        self, txn: LoggingTransaction, medium: str, address: str
     ) -> Optional[str]:
         """Returns user id from threepid
 
@@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
+    async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
         return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
@@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     async def add_user_bound_threepid(
         self, user_id: str, medium: str, address: str, id_server: str
-    ):
+    ) -> None:
         """The server proxied a bind request to the given identity server on
         behalf of the given user. We need to remember this in case the user
         asks us to unbind the threepid.
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         assert address or sid
 
-        def get_threepid_validation_session_txn(txn):
+        def get_threepid_validation_session_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                 last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             session_id: The ID of the session to delete
         """
 
-        def delete_threepid_session_txn(txn):
+        def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+        def cull_expired_threepid_validation_tokens_txn(
+            txn: LoggingTransaction, ts: int
+        ) -> None:
             sql = """
             DELETE FROM threepid_validation_token WHERE
             expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @wrap_as_background_process("account_validity_set_expiration_dates")
-    async def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self) -> None:
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
         """
 
-        def select_users_with_no_expiration_date_txn(txn):
+        def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
             """Retrieves the list of registered users with no expiration date from the
             database, filtering out deactivated users.
             """
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             select_users_with_no_expiration_date_txn,
         )
 
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+    def set_expiration_date_for_user_txn(
+        self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+    ) -> None:
         """Sets an expiration date to the account with the given user ID.
 
         Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token pending use
         """
 
-        def _set_registration_token_pending_txn(txn):
+        def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
             pending = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 "registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 updatevalues={"pending": pending + 1},
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_registration_token_pending", _set_registration_token_pending_txn
         )
 
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token to be 'used'
         """
 
-        def _use_registration_token_txn(txn):
+        def _use_registration_token_txn(txn: LoggingTransaction) -> None:
             # Normally, res is Optional[Dict[str, Any]].
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 },
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "use_registration_token", _use_registration_token_txn
         )
 
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A list of dicts, each containing details of a token.
         """
 
-        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+        def select_registration_tokens_txn(
+            txn: LoggingTransaction, now: int, valid: Optional[bool]
+        ) -> List[Dict[str, Any]]:
             if valid is None:
                 # Return all tokens regardless of validity
                 txn.execute("SELECT * FROM registration_tokens")
@@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             Whether the row was inserted or not.
         """
 
-        def _create_registration_token_txn(txn):
+        def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 "registration_tokens",
@@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A dict with all info about the token, or None if token doesn't exist.
         """
 
-        def _update_registration_token_txn(txn):
+        def _update_registration_token_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             try:
                 self.db_pool.simple_update_one_txn(
                     txn,
@@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ) -> Optional[RefreshTokenLookupResult]:
         """Lookup a refresh token with hints about its validity."""
 
-        def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+        def _lookup_refresh_token_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[RefreshTokenLookupResult]:
             txn.execute(
                 """
                 SELECT
@@ -1807,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             unique=False,
         )
 
-    async def _background_update_set_deactivated_flag(self, progress, batch_size):
+    async def _background_update_set_deactivated_flag(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
 
         last_user = progress.get("user_id", "")
 
-        def _background_update_set_deactivated_flag_txn(txn):
+        def _background_update_set_deactivated_flag_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, int]:
             txn.execute(
                 """
                 SELECT
@@ -1886,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             deactivated,
         )
 
-    def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+    def set_user_deactivated_status_txn(
+        self, txn: LoggingTransaction, user_id: str, deactivated: bool
+    ) -> None:
         self.db_pool.simple_update_one_txn(
             txn=txn,
             table="users",
@@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         return next_id
 
-    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+    def _set_device_for_access_token_txn(
+        self, txn: LoggingTransaction, token: str, device_id: str
+    ) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
         )
@@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
     def _register_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         password_hash: Optional[str],
         was_guest: bool,
@@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool,
         user_type: Optional[str],
         shadow_banned: bool,
-    ):
+    ) -> None:
         user_id_obj = UserID.from_string(user_id)
 
         now = int(self._clock.time())
@@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             pointless. Use flush_user separately.
         """
 
-        def user_set_password_hash_txn(txn):
+        def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
@@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             StoreError(404) if user not found
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn,
                 table="users",
@@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             StoreError(404) if user not found
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn,
                 table="users",
@@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             A tuple of (token, token id, device id) for each of the deleted tokens
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
             keyvalues = {"user_id": user_id}
             if device_id is not None:
                 keyvalues["device_id"] = device_id
@@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
     async def delete_access_token(self, access_token: str) -> None:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
             )
@@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         await self.db_pool.runInteraction("delete_access_token", f)
 
     async def delete_refresh_token(self, refresh_token: str) -> None:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
                 txn, table="refresh_tokens", keyvalues={"token": refresh_token}
             )
@@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         """
 
         # Insert everything into a transaction in order to run atomically
-        def validate_threepid_session_txn(txn):
+        def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="threepid_validation_session",
@@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 longer be valid
         """
 
-        def start_or_continue_validation_session_txn(txn):
+        def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
             # Create or update a validation session
             self.db_pool.simple_upsert_txn(
                 txn,