summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py57
1 files changed, 36 insertions, 21 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py

index 068ad22b30..48bda66f3e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -889,6 +891,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1258,12 +1261,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1302,15 +1305,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1326,6 +1336,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link