diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index df7f8a43b7..868803e169 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -32,7 +32,6 @@ from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
- ThreepidValidationError,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -149,30 +148,6 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class ThreepidResult:
- medium: str
- address: str
- validated_at: int
- added_at: int
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class ThreepidValidationSession:
- address: str
- """address of the 3pid"""
- medium: str
- """medium of the 3pid"""
- client_secret: str
- """a secret provided by the client for this validation session"""
- session_id: str
- """ID of the validation session"""
- last_send_attempt: int
- """a number serving to dedupe send attempts for this session"""
- validated_at: Optional[int]
- """timestamp of when this session was validated if so"""
-
-
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -215,12 +190,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
self._set_expiration_date_when_missing,
)
- # Create a background job for culling expired 3PID validity tokens
- if hs.config.worker.run_background_tasks:
- self._clock.looping_call(
- self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
- )
-
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
@@ -583,7 +552,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
- async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
+ async def set_user_type(
+ self, user: UserID, user_type: Optional[Union[UserTypes, str]]
+ ) -> None:
"""Sets the user type.
Args:
@@ -683,7 +654,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
retcol="user_type",
allow_none=True,
)
- return res is None
+ return res is None or res not in [UserTypes.BOT, UserTypes.SUPPORT]
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
@@ -759,17 +730,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_external_id, (auth_provider, external_id)
+ )
- self.db_pool.simple_insert_txn(
+ # This INSERT ... ON CONFLICT DO NOTHING statement will cause a
+ # 'could not serialize access due to concurrent update'
+ # if the row is added concurrently by another transaction.
+ # This is exactly what we want, as it makes the transaction get retried
+ # in a new snapshot where we can check for a genuine conflict.
+ was_inserted = self.db_pool.simple_upsert_txn(
txn,
table="user_external_ids",
- values={
- "auth_provider": auth_provider,
- "external_id": external_id,
- "user_id": user_id,
- },
+ keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+ values={},
+ insertion_values={"user_id": user_id},
)
+ if not was_inserted:
+ existing_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="user_external_ids",
+ keyvalues={"auth_provider": auth_provider, "user_id": user_id},
+ retcol="external_id",
+ allow_none=True,
+ )
+
+ if existing_id != external_id:
+ raise ExternalIDReuseException(
+ f"{user_id!r} has external id {existing_id!r} for {auth_provider} but trying to add {external_id!r}"
+ )
+
async def remove_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
@@ -789,6 +780,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
},
desc="remove_user_external_id",
)
+ await self.invalidate_cache_and_stream(
+ "get_user_by_external_id", (auth_provider, external_id)
+ )
async def replace_user_external_id(
self,
@@ -809,29 +803,20 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
ExternalIDReuseException if the new external_id could not be mapped.
"""
- def _remove_user_external_ids_txn(
+ def _replace_user_external_id_txn(
txn: LoggingTransaction,
- user_id: str,
) -> None:
- """Remove all mappings from external user ids to a mxid
- If these mappings are not found, this method does nothing.
-
- Args:
- user_id: complete mxid that it is mapped to
- """
-
self.db_pool.simple_delete_txn(
txn,
table="user_external_ids",
keyvalues={"user_id": user_id},
)
- def _replace_user_external_id_txn(
- txn: LoggingTransaction,
- ) -> None:
- _remove_user_external_ids_txn(txn, user_id)
-
for auth_provider, external_id in record_external_ids:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_external_id, (auth_provider, external_id)
+ )
+
self._record_user_external_id_txn(
txn,
auth_provider,
@@ -847,6 +832,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
except self.database_engine.module.IntegrityError:
raise ExternalIDReuseException()
+ @cached()
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
@@ -944,10 +930,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
async def count_real_users(self) -> int:
- """Counts all users without a special user_type registered on the homeserver."""
+ """Counts all users without the bot or support user_types registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ txn.execute(
+ f"SELECT COUNT(*) FROM users WHERE user_type IS NULL OR user_type NOT IN ('{UserTypes.BOT}', '{UserTypes.SUPPORT}')"
+ )
row = txn.fetchone()
assert row is not None
return row[0]
@@ -965,161 +953,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return str(next_id)
- async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
- """Returns user id from threepid
-
- Args:
- medium: threepid medium e.g. email
- address: threepid address e.g. me@example.com. This must already be
- in canonical form.
-
- Returns:
- The user ID or None if no user id/threepid mapping exists
- """
- user_id = await self.db_pool.runInteraction(
- "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
- )
- return user_id
-
- def get_user_id_by_threepid_txn(
- self, txn: LoggingTransaction, medium: str, address: str
- ) -> Optional[str]:
- """Returns user id from threepid
-
- Args:
- txn:
- medium: threepid medium e.g. email
- address: threepid address e.g. me@example.com
-
- Returns:
- user id, or None if no user id/threepid mapping exists
- """
- return self.db_pool.simple_select_one_onecol_txn(
- txn,
- "user_threepids",
- {"medium": medium, "address": address},
- "user_id",
- True,
- )
-
- async def user_add_threepid(
- self,
- user_id: str,
- medium: str,
- address: str,
- validated_at: int,
- added_at: int,
- ) -> None:
- await self.db_pool.simple_upsert(
- "user_threepids",
- {"medium": medium, "address": address},
- {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
- )
-
- async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
- results = cast(
- List[Tuple[str, str, int, int]],
- await self.db_pool.simple_select_list(
- "user_threepids",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address", "validated_at", "added_at"],
- desc="user_get_threepids",
- ),
- )
- return [
- ThreepidResult(
- medium=r[0],
- address=r[1],
- validated_at=r[2],
- added_at=r[3],
- )
- for r in results
- ]
-
- async def user_delete_threepid(
- self, user_id: str, medium: str, address: str
- ) -> None:
- await self.db_pool.simple_delete(
- "user_threepids",
- keyvalues={"user_id": user_id, "medium": medium, "address": address},
- desc="user_delete_threepid",
- )
-
- async def add_user_bound_threepid(
- self, user_id: str, medium: str, address: str, id_server: str
- ) -> None:
- """The server proxied a bind request to the given identity server on
- behalf of the given user. We need to remember this in case the user
- asks us to unbind the threepid.
-
- Args:
- user_id
- medium
- address
- id_server
- """
- # We need to use an upsert, in case they user had already bound the
- # threepid
- await self.db_pool.simple_upsert(
- table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- "id_server": id_server,
- },
- values={},
- insertion_values={},
- desc="add_user_bound_threepid",
- )
-
- async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
- """Get the threepids that a user has bound to an identity server through the homeserver
- The homeserver remembers where binds to an identity server occurred. Using this
- method can retrieve those threepids.
-
- Args:
- user_id: The ID of the user to retrieve threepids for
-
- Returns:
- List of tuples of two strings:
- medium: The medium of the threepid (e.g "email")
- address: The address of the threepid (e.g "bob@example.com")
- """
- return cast(
- List[Tuple[str, str]],
- await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
- ),
- )
-
- async def remove_user_bound_threepid(
- self, user_id: str, medium: str, address: str, id_server: str
- ) -> None:
- """The server proxied an unbind request to the given identity server on
- behalf of the given user, so we remove the mapping of threepid to
- identity server.
-
- Args:
- user_id
- medium
- address
- id_server
- """
- await self.db_pool.simple_delete(
- table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- "id_server": id_server,
- },
- desc="remove_user_bound_threepid",
- )
-
async def get_id_servers_user_bound(
self, user_id: str, medium: str, address: str
) -> List[str]:
@@ -1204,123 +1037,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return bool(res)
- async def get_threepid_validation_session(
- self,
- medium: Optional[str],
- client_secret: str,
- address: Optional[str] = None,
- sid: Optional[str] = None,
- validated: Optional[bool] = True,
- ) -> Optional[ThreepidValidationSession]:
- """Gets a session_id and last_send_attempt (if available) for a
- combination of validation metadata
-
- Args:
- medium: The medium of the 3PID
- client_secret: A unique string provided by the client to help identify this
- validation attempt
- address: The address of the 3PID
- sid: The ID of the validation session
- validated: Whether sessions should be filtered by
- whether they have been validated already or not. None to
- perform no filtering
-
- Returns:
- A ThreepidValidationSession or None if a validation session is not found
- """
- if not client_secret:
- raise SynapseError(
- 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
- )
-
- keyvalues = {"client_secret": client_secret}
- if medium:
- keyvalues["medium"] = medium
- if address:
- keyvalues["address"] = address
- if sid:
- keyvalues["session_id"] = sid
-
- assert address or sid
-
- def get_threepid_validation_session_txn(
- txn: LoggingTransaction,
- ) -> Optional[ThreepidValidationSession]:
- sql = """
- SELECT address, session_id, medium, client_secret,
- last_send_attempt, validated_at
- FROM threepid_validation_session WHERE %s
- """ % (
- " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- )
-
- if validated is not None:
- sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
-
- sql += " LIMIT 1"
-
- txn.execute(sql, list(keyvalues.values()))
- row = txn.fetchone()
- if not row:
- return None
-
- return ThreepidValidationSession(
- address=row[0],
- session_id=row[1],
- medium=row[2],
- client_secret=row[3],
- last_send_attempt=row[4],
- validated_at=row[5],
- )
-
- return await self.db_pool.runInteraction(
- "get_threepid_validation_session", get_threepid_validation_session_txn
- )
-
- async def delete_threepid_session(self, session_id: str) -> None:
- """Removes a threepid validation session from the database. This can
- be done after validation has been performed and whatever action was
- waiting on it has been carried out
-
- Args:
- session_id: The ID of the session to delete
- """
-
- def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="threepid_validation_token",
- keyvalues={"session_id": session_id},
- )
- self.db_pool.simple_delete_txn(
- txn,
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- )
-
- await self.db_pool.runInteraction(
- "delete_threepid_session", delete_threepid_session_txn
- )
-
- @wrap_as_background_process("cull_expired_threepid_validation_tokens")
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(
- txn: LoggingTransaction, ts: int
- ) -> None:
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self._clock.time_msec(),
- )
-
@wrap_as_background_process("account_validity_set_expiration_dates")
async def _set_expiration_date_when_missing(self) -> None:
"""
@@ -1512,15 +1228,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- pending, completed = cast(
- Tuple[int, int],
- self.db_pool.simple_select_one_txn(
- txn,
- "registration_tokens",
- keyvalues={"token": token},
- retcols=["pending", "completed"],
- ),
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=("pending", "completed"),
)
+ pending = int(row[0])
+ completed = int(row[1])
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
@@ -2093,6 +1808,136 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
func=is_user_approved_txn,
)
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
+
+ def set_user_deactivated_status_txn(
+ self, txn: LoggingTransaction, user_id: str, deactivated: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.is_guest, (user_id,))
+
+ async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None:
+ """
+ Set whether the user's account is suspended in the `users` table.
+
+ Args:
+ user_id: The user ID of the user in question
+ suspended: True if the user is suspended, false if not
+ """
+ await self.db_pool.runInteraction(
+ "set_user_suspended_status",
+ self.set_user_suspended_status_txn,
+ user_id,
+ suspended,
+ )
+
+ def set_user_suspended_status_txn(
+ self, txn: LoggingTransaction, user_id: str, suspended: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"suspended": suspended},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_suspended_status, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
+ """Set the `locked` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ locked: The value to set for `locked`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_locked_status",
+ self.set_user_locked_status_txn,
+ user_id,
+ locked,
+ )
+
+ def set_user_locked_status_txn(
+ self, txn: LoggingTransaction, user_id: str, locked: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"locked": locked},
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ async def update_user_approval_status(
+ self, user_id: UserID, approved: bool
+ ) -> None:
+ """Set the user's 'approved' flag to the given value.
+
+ The boolean will be turned into an int (in update_user_approval_status_txn)
+ because the column is a smallint.
+
+ Args:
+ user_id: the user to update the flag for.
+ approved: the value to set the flag to.
+ """
+ await self.db_pool.runInteraction(
+ "update_user_approval_status",
+ self.update_user_approval_status_txn,
+ user_id.to_string(),
+ approved,
+ )
+
+ def update_user_approval_status_txn(
+ self, txn: LoggingTransaction, user_id: str, approved: bool
+ ) -> None:
+ """Set the user's 'approved' flag to the given value.
+
+ The boolean is turned into an int because the column is a smallint.
+
+ Args:
+ txn: the current database transaction.
+ user_id: the user to update the flag for.
+ approved: the value to set the flag to.
+ """
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"approved": approved},
+ )
+
+ # Invalidate the caches of methods that read the value of the 'approved' flag.
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -2205,117 +2050,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return nb_processed
- async def set_user_deactivated_status(
- self, user_id: str, deactivated: bool
- ) -> None:
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id: The ID of the user to set the status for.
- deactivated: The value to set for `deactivated`.
- """
-
- await self.db_pool.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id,
- deactivated,
- )
-
- def set_user_deactivated_status_txn(
- self, txn: LoggingTransaction, user_id: str, deactivated: bool
- ) -> None:
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
- self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- txn.call_after(self.is_guest.invalidate, (user_id,))
-
- async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None:
- """
- Set whether the user's account is suspended in the `users` table.
-
- Args:
- user_id: The user ID of the user in question
- suspended: True if the user is suspended, false if not
- """
- await self.db_pool.runInteraction(
- "set_user_suspended_status",
- self.set_user_suspended_status_txn,
- user_id,
- suspended,
- )
-
- def set_user_suspended_status_txn(
- self, txn: LoggingTransaction, user_id: str, suspended: bool
- ) -> None:
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"suspended": suspended},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_suspended_status, (user_id,)
- )
- self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
-
- async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
- """Set the `locked` property for the provided user to the provided value.
-
- Args:
- user_id: The ID of the user to set the status for.
- locked: The value to set for `locked`.
- """
-
- await self.db_pool.runInteraction(
- "set_user_locked_status",
- self.set_user_locked_status_txn,
- user_id,
- locked,
- )
-
- def set_user_locked_status_txn(
- self, txn: LoggingTransaction, user_id: str, locked: bool
- ) -> None:
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"locked": locked},
- )
- self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
- self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
-
- def update_user_approval_status_txn(
- self, txn: LoggingTransaction, user_id: str, approved: bool
- ) -> None:
- """Set the user's 'approved' flag to the given value.
-
- The boolean is turned into an int because the column is a smallint.
-
- Args:
- txn: the current database transaction.
- user_id: the user to update the flag for.
- approved: the value to set the flag to.
- """
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"approved": approved},
- )
-
- # Invalidate the caches of methods that read the value of the 'approved' flag.
- self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
-
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@@ -2326,9 +2060,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
- self._ignore_unknown_session_error = (
- hs.config.server.request_token_inhibit_3pid_errors
- )
+ self._ignore_unknown_session_error = False # Used to use whether 3pid errors were suppressed or not... Problem?
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
@@ -2514,7 +2246,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
the user, setting their displayname to the given value
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
- or None for a normal user.
+ a custom value set in the configuration file, or None for a normal
+ user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an
@@ -2796,96 +2529,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
desc="add_user_pending_deactivation",
)
- async def validate_threepid_session(
- self, session_id: str, client_secret: str, token: str, current_ts: int
- ) -> Optional[str]:
- """Attempt to validate a threepid session using a token
-
- Args:
- session_id: The id of a validation session
- client_secret: A unique string provided by the client to help identify
- this validation attempt
- token: A validation token
- current_ts: The current unix time in milliseconds. Used for checking
- token expiry status
-
- Raises:
- ThreepidValidationError: if a matching validation token was not found or has
- expired
-
- Returns:
- A str representing a link to redirect the user to if there is one.
- """
-
- # Insert everything into a transaction in order to run atomically
- def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
- row = self.db_pool.simple_select_one_txn(
- txn,
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- retcols=["client_secret", "validated_at"],
- allow_none=True,
- )
-
- if not row:
- if self._ignore_unknown_session_error:
- # If we need to inhibit the error caused by an incorrect session ID,
- # use None as placeholder values for the client secret and the
- # validation timestamp.
- # It shouldn't be an issue because they're both only checked after
- # the token check, which should fail. And if it doesn't for some
- # reason, the next check is on the client secret, which is NOT NULL,
- # so we don't have to worry about the client secret matching by
- # accident.
- row = None, None
- else:
- raise ThreepidValidationError("Unknown session_id")
-
- retrieved_client_secret, validated_at = row
-
- row = self.db_pool.simple_select_one_txn(
- txn,
- table="threepid_validation_token",
- keyvalues={"session_id": session_id, "token": token},
- retcols=["expires", "next_link"],
- allow_none=True,
- )
-
- if not row:
- raise ThreepidValidationError(
- "Validation token not found or has expired"
- )
- expires, next_link = row
-
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- "This client_secret does not match the provided session_id"
- )
-
- # If the session is already validated, no need to revalidate
- if validated_at:
- return next_link
-
- if expires <= current_ts:
- raise ThreepidValidationError(
- "This token has expired. Please request a new one"
- )
-
- # Looks good. Validate the session
- self.db_pool.simple_update_txn(
- txn,
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- updatevalues={"validated_at": self._clock.time_msec()},
- )
-
- return next_link
-
- # Return next_link if it exists
- return await self.db_pool.runInteraction(
- "validate_threepid_session_txn", validate_threepid_session_txn
- )
-
async def start_or_continue_validation_session(
self,
medium: str,
@@ -2944,25 +2587,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def update_user_approval_status(
- self, user_id: UserID, approved: bool
- ) -> None:
- """Set the user's 'approved' flag to the given value.
-
- The boolean will be turned into an int (in update_user_approval_status_txn)
- because the column is a smallint.
-
- Args:
- user_id: the user to update the flag for.
- approved: the value to set the flag to.
- """
- await self.db_pool.runInteraction(
- "update_user_approval_status",
- self.update_user_approval_status_txn,
- user_id.to_string(),
- approved,
- )
-
@wrap_as_background_process("delete_expired_login_tokens")
async def _delete_expired_login_tokens(self) -> None:
"""Remove login tokens with expiry dates that have passed."""
|