summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py156
1 files changed, 155 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2996d6bb4d..0255295317 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 import attr
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+    Codes,
+    NotFoundError,
+    StoreError,
+    SynapseError,
+    ThreepidValidationError,
+)
 from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
     because this external id is given to an other user."""
 
 
+class LoginTokenExpired(Exception):
+    """Exception if the login token sent expired"""
+
+
+class LoginTokenReused(Exception):
+    """Exception if the login token sent was already used"""
+
+
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:
     """Result of looking up an access token.
@@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
     If None, the session can be refreshed indefinitely."""
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class LoginTokenLookupResult:
+    """Result of looking up a login token."""
+
+    user_id: str
+    """The user this token belongs to."""
+
+    auth_provider_id: Optional[str]
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id: Optional[str]
+    """The session ID advertised by the SSO Identity Provider."""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "replace_refresh_token", _replace_refresh_token_txn
         )
 
+    async def add_login_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        expiry_ts: int,
+        auth_provider_id: Optional[str],
+        auth_provider_session_id: Optional[str],
+    ) -> None:
+        """Adds a short-term login token for the given user.
+
+        Args:
+            user_id: The user ID.
+            token: The new login token to add.
+            expiry_ts (milliseconds since the epoch): Time after which the login token
+                cannot be used.
+            auth_provider_id: The SSO Identity Provider that the user authenticated with
+                to get this token, if any
+            auth_provider_session_id: The session ID advertised by the SSO Identity
+                Provider, if any.
+        """
+        await self.db_pool.simple_insert(
+            "login_tokens",
+            {
+                "token": token,
+                "user_id": user_id,
+                "expiry_ts": expiry_ts,
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            desc="add_login_token_to_user",
+        )
+
+    def _consume_login_token(
+        self,
+        txn: LoggingTransaction,
+        token: str,
+        ts: int,
+    ) -> LoginTokenLookupResult:
+        values = self.db_pool.simple_select_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            retcols=(
+                "user_id",
+                "expiry_ts",
+                "used_ts",
+                "auth_provider_id",
+                "auth_provider_session_id",
+            ),
+            allow_none=True,
+        )
+
+        if values is None:
+            raise NotFoundError()
+
+        self.db_pool.simple_update_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            updatevalues={"used_ts": ts},
+        )
+        user_id = values["user_id"]
+        expiry_ts = values["expiry_ts"]
+        used_ts = values["used_ts"]
+        auth_provider_id = values["auth_provider_id"]
+        auth_provider_session_id = values["auth_provider_session_id"]
+
+        # Token was already used
+        if used_ts is not None:
+            raise LoginTokenReused()
+
+        # Token expired
+        if ts > int(expiry_ts):
+            raise LoginTokenExpired()
+
+        return LoginTokenLookupResult(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+    async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
+        """Lookup a login token and consume it.
+
+        Args:
+            token: The login token.
+
+        Returns:
+            The data stored with that token, including the `user_id`. Returns `None` if
+            the token does not exist or if it expired.
+
+        Raises:
+            NotFound if the login token was not found in database
+            LoginTokenExpired if the login token expired
+            LoginTokenReused if the login token was already used
+        """
+        return await self.db_pool.runInteraction(
+            "consume_login_token",
+            self._consume_login_token,
+            token,
+            self._clock.time_msec(),
+        )
+
     @cached()
     async def is_guest(self, user_id: str) -> bool:
         res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             and hs.config.experimental.msc3866.require_approval_for_new_accounts
         )
 
+        # Create a background job for removing expired login tokens
+        if hs.config.worker.run_background_tasks:
+            self._clock.looping_call(
+                self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             approved,
         )
 
+    @wrap_as_background_process("delete_expired_login_tokens")
+    async def _delete_expired_login_tokens(self) -> None:
+        """Remove login tokens with expiry dates that have passed."""
+
+        def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
+            txn.execute(sql, (ts,))
+
+        # We keep the expired tokens for an extra 5 minutes so we can measure how many
+        # times a token is being used after its expiry
+        now = self._clock.time_msec()
+        await self.db_pool.runInteraction(
+            "delete_expired_login_tokens",
+            _delete_expired_login_tokens_txn,
+            now - (5 * 60 * 1000),
+        )
+
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """