summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-09-08 11:01:19 +0100
committerErik Johnston <erik@matrix.org>2021-09-08 11:01:19 +0100
commitd381eae552b5b0a1f35fe55318517b29187d64db (patch)
treeff02edf7baabcd1e48454ff77ab2feb9ef78aefa /synapse/storage/databases/main/registration.py
parentDefer verification to thread pool (diff)
parentReturn stripped m.space.child events via the space summary. (#10760) (diff)
downloadsynapse-d381eae552b5b0a1f35fe55318517b29187d64db.tar.xz
Merge remote-tracking branch 'origin/develop' into erikj/join_logging
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py393
1 files changed, 385 insertions, 8 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index 6ad1a0cf7f..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config @@ -571,6 +599,28 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -704,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -726,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -741,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, @@ -1107,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. + """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" + ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. + + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """