diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 068ad22b30..48bda66f3e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
- def get_user_by_id(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
- def user_get_bound_threepids(self, user_id):
+ async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
Args:
- user_id (str): The ID of the user to retrieve threepids for
+ user_id: The ID of the user to retrieve threepids for
Returns:
- Deferred[list[dict]]: List of dictionaries containing the following:
+ List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self.db_pool.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="remove_user_bound_threepid",
)
- def get_id_servers_user_bound(self, user_id, medium, address):
+ async def get_id_servers_user_bound(
+ self, user_id: str, medium: str, address: str
+ ) -> List[str]:
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
- user_id (str)
- medium (str)
- address (str)
+ user_id: The user to query for identity servers.
+ medium: The medium to query for identity servers.
+ address: The address to query for identity servers.
Returns:
- Deferred[list[str]]: Resolves to a list of identity servers
+ A list of identity servers
"""
- return self.db_pool.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@@ -889,6 +891,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -1258,12 +1261,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1302,15 +1305,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1326,6 +1336,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
|