diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2996d6bb4d..31f0f2bd3d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ NotFoundError,
+ StoreError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
because this external id is given to an other user."""
+class LoginTokenExpired(Exception):
+ """Exception if the login token sent expired"""
+
+
+class LoginTokenReused(Exception):
+ """Exception if the login token sent was already used"""
+
+
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
If None, the session can be refreshed indefinitely."""
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class LoginTokenLookupResult:
+ """Result of looking up a login token."""
+
+ user_id: str
+ """The user this token belongs to."""
+
+ auth_provider_id: Optional[str]
+ """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+ auth_provider_session_id: Optional[str]
+ """The session ID advertised by the SSO Identity Provider."""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -925,7 +953,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Returns user id from threepid
Args:
- txn (cursor):
+ txn:
medium: threepid medium e.g. email
address: threepid address e.g. me@example.com
@@ -1255,8 +1283,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Sets an expiration date to the account with the given user ID.
Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
+ user_id: User ID to set an expiration date for.
+ use_delta: If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.
@@ -1789,6 +1817,130 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ async def add_login_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ expiry_ts: int,
+ auth_provider_id: Optional[str],
+ auth_provider_session_id: Optional[str],
+ ) -> None:
+ """Adds a short-term login token for the given user.
+
+ Args:
+ user_id: The user ID.
+ token: The new login token to add.
+ expiry_ts (milliseconds since the epoch): Time after which the login token
+ cannot be used.
+ auth_provider_id: The SSO Identity Provider that the user authenticated with
+ to get this token, if any
+ auth_provider_session_id: The session ID advertised by the SSO Identity
+ Provider, if any.
+ """
+ await self.db_pool.simple_insert(
+ "login_tokens",
+ {
+ "token": token,
+ "user_id": user_id,
+ "expiry_ts": expiry_ts,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="add_login_token_to_user",
+ )
+
+ def _consume_login_token(
+ self,
+ txn: LoggingTransaction,
+ token: str,
+ ts: int,
+ ) -> LoginTokenLookupResult:
+ values = self.db_pool.simple_select_one_txn(
+ txn,
+ "login_tokens",
+ keyvalues={"token": token},
+ retcols=(
+ "user_id",
+ "expiry_ts",
+ "used_ts",
+ "auth_provider_id",
+ "auth_provider_session_id",
+ ),
+ allow_none=True,
+ )
+
+ if values is None:
+ raise NotFoundError()
+
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "login_tokens",
+ keyvalues={"token": token},
+ updatevalues={"used_ts": ts},
+ )
+ user_id = values["user_id"]
+ expiry_ts = values["expiry_ts"]
+ used_ts = values["used_ts"]
+ auth_provider_id = values["auth_provider_id"]
+ auth_provider_session_id = values["auth_provider_session_id"]
+
+ # Token was already used
+ if used_ts is not None:
+ raise LoginTokenReused()
+
+ # Token expired
+ if ts > int(expiry_ts):
+ raise LoginTokenExpired()
+
+ return LoginTokenLookupResult(
+ user_id=user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
+ """Lookup a login token and consume it.
+
+ Args:
+ token: The login token.
+
+ Returns:
+ The data stored with that token, including the `user_id`. Returns `None` if
+ the token does not exist or if it expired.
+
+ Raises:
+ NotFound if the login token was not found in database
+ LoginTokenExpired if the login token expired
+ LoginTokenReused if the login token was already used
+ """
+ return await self.db_pool.runInteraction(
+ "consume_login_token",
+ self._consume_login_token,
+ token,
+ self._clock.time_msec(),
+ )
+
+ async def invalidate_login_tokens_by_session_id(
+ self, auth_provider_id: str, auth_provider_session_id: str
+ ) -> None:
+ """Invalidate login tokens with the given IdP session ID.
+
+ Args:
+ auth_provider_id: The SSO Identity Provider that the user authenticated with
+ to get this token
+ auth_provider_session_id: The session ID advertised by the SSO Identity
+ Provider
+ """
+ await self.db_pool.simple_update(
+ table="login_tokens",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ updatevalues={"used_ts": self._clock.time_msec()},
+ desc="invalidate_login_tokens_by_session_id",
+ )
+
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2171,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
+ # Create a background job for removing expired login tokens
+ if hs.config.worker.run_background_tasks:
+ self._clock.looping_call(
+ self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
+ )
+
async def add_access_token_to_user(
self,
user_id: str,
@@ -2617,6 +2775,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
approved,
)
+ @wrap_as_background_process("delete_expired_login_tokens")
+ async def _delete_expired_login_tokens(self) -> None:
+ """Remove login tokens with expiry dates that have passed."""
+
+ def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
+ txn.execute(sql, (ts,))
+
+ # We keep the expired tokens for an extra 5 minutes so we can measure how many
+ # times a token is being used after its expiry
+ now = self._clock.time_msec()
+ await self.db_pool.runInteraction(
+ "delete_expired_login_tokens",
+ _delete_expired_login_tokens_txn,
+ now - (5 * 60 * 1000),
+ )
+
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
|