summary refs log tree commit diff
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2022-10-26 12:45:41 +0200
committerGitHub <noreply@github.com>2022-10-26 11:45:41 +0100
commit8756d5c87efc5637da55c9e21d2a4eb2369ba693 (patch)
tree38b8f68e61fa285fba1bc345b006fe1a9e3af026
parentUnified search query syntax using the full-text search capabilities of the un... (diff)
downloadsynapse-8756d5c87efc5637da55c9e21d2a4eb2369ba693.tar.xz
Save login tokens in database (#13844)
* Save login tokens in database

Signed-off-by: Quentin Gliech <quenting@element.io>

* Add upgrade notes

* Track login token reuse in a Prometheus metric

Signed-off-by: Quentin Gliech <quenting@element.io>
-rw-r--r--changelog.d/13844.misc1
-rw-r--r--docs/upgrade.md9
-rw-r--r--synapse/handlers/auth.py64
-rw-r--r--synapse/module_api/__init__.py41
-rw-r--r--synapse/rest/client/login.py3
-rw-r--r--synapse/rest/client/login_token_request.py5
-rw-r--r--synapse/storage/databases/main/registration.py156
-rw-r--r--synapse/storage/schema/main/delta/73/10login_tokens.sql35
-rw-r--r--synapse/util/macaroons.py87
-rw-r--r--tests/handlers/test_auth.py135
-rw-r--r--tests/util/test_macaroons.py28
11 files changed, 337 insertions, 227 deletions
diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc
new file mode 100644
index 0000000000..66f4414df7
--- /dev/null
+++ b/changelog.d/13844.misc
@@ -0,0 +1 @@
+Save login tokens in database and prevent login token reuse.
diff --git a/docs/upgrade.md b/docs/upgrade.md
index b81385b191..78c34d0c15 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -88,6 +88,15 @@ process, for example:
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     ```
 
+# Upgrading to v1.71.0
+
+## Removal of the `generate_short_term_login_token` module API method
+
+As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed.
+
+Modules relying on it can instead use the `create_login_token` method.
+
+
 # Upgrading to v1.69.0
 
 ## Changes to the receipts replication streams
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f5f0e0e7a7..8b9ef25d29 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -38,6 +38,7 @@ from typing import (
 import attr
 import bcrypt
 import unpaddedbase64
+from prometheus_client import Counter
 
 from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
@@ -48,6 +49,7 @@ from synapse.api.errors import (
     Codes,
     InteractiveAuthIncompleteError,
     LoginError,
+    NotFoundError,
     StoreError,
     SynapseError,
     UserDeactivatedError,
@@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.registration import (
+    LoginTokenExpired,
+    LoginTokenLookupResult,
+    LoginTokenReused,
+)
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import LoginTokenAttributes
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
@@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
 
 INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
 
+invalid_login_token_counter = Counter(
+    "synapse_user_login_invalid_login_tokens",
+    "Counts the number of rejected m.login.token on /login",
+    ["reason"],
+)
+
 
 def convert_client_dict_legacy_fields_to_identifier(
     submission: JsonDict,
@@ -883,6 +895,25 @@ class AuthHandler:
 
         return True
 
+    async def create_login_token_for_user_id(
+        self,
+        user_id: str,
+        duration_ms: int = (2 * 60 * 1000),
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
+    ) -> str:
+        login_token = self.generate_login_token()
+        now = self._clock.time_msec()
+        expiry_ts = now + duration_ms
+        await self.store.add_login_token_to_user(
+            user_id=user_id,
+            token=login_token,
+            expiry_ts=expiry_ts,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+        return login_token
+
     async def create_refresh_token_for_user_id(
         self,
         user_id: str,
@@ -1401,6 +1432,18 @@ class AuthHandler:
             return None
         return user_id
 
+    def generate_login_token(self) -> str:
+        """Generates an opaque string, for use as an short-term login token"""
+
+        # we use the following format for access tokens:
+        #    syl_<random string>_<base62 crc check>
+
+        random_string = stringutils.random_string(20)
+        base = f"syl_{random_string}"
+
+        crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        return f"{base}_{crc}"
+
     def generate_access_token(self, for_user: UserID) -> str:
         """Generates an opaque string, for use as an access token"""
 
@@ -1427,16 +1470,17 @@ class AuthHandler:
         crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
         return f"{base}_{crc}"
 
-    async def validate_short_term_login_token(
-        self, login_token: str
-    ) -> LoginTokenAttributes:
+    async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
         try:
-            res = self.macaroon_gen.verify_short_term_login_token(login_token)
-        except Exception:
-            raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
+            return await self.store.consume_login_token(login_token)
+        except LoginTokenExpired:
+            invalid_login_token_counter.labels("expired").inc()
+        except LoginTokenReused:
+            invalid_login_token_counter.labels("reused").inc()
+        except NotFoundError:
+            invalid_login_token_counter.labels("not found").inc()
 
-        await self.auth_blocking.check_auth_blocking(res.user_id)
-        return res
+        raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
     async def delete_access_token(self, access_token: str) -> None:
         """Invalidate a single access token
@@ -1711,7 +1755,7 @@ class AuthHandler:
             )
 
         # Create a login token
-        login_token = self.macaroon_gen.generate_short_term_login_token(
+        login_token = await self.create_login_token_for_user_id(
             registered_user_id,
             auth_provider_id=auth_provider_id,
             auth_provider_session_id=auth_provider_session_id,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 6a6ae208d1..30e689d00d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -771,50 +771,11 @@ class ModuleApi:
             auth_provider_session_id: The session ID got during login from the SSO IdP,
                 if any.
         """
-        # The deprecated `generate_short_term_login_token` method defaulted to an empty
-        # string for the `auth_provider_id` because of how the underlying macaroon was
-        # generated. This will change to a proper NULL-able field when the tokens get
-        # moved to the database.
-        return self._hs.get_macaroon_generator().generate_short_term_login_token(
+        return await self._hs.get_auth_handler().create_login_token_for_user_id(
             user_id,
-            auth_provider_id or "",
-            auth_provider_session_id,
             duration_in_ms,
-        )
-
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        duration_in_ms: int = (2 * 60 * 1000),
-        auth_provider_id: str = "",
-        auth_provider_session_id: Optional[str] = None,
-    ) -> str:
-        """Generate a login token suitable for m.login.token authentication
-
-        Added in Synapse v1.9.0.
-
-        This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will
-        be removed in Synapse 1.71.0.
-
-        Args:
-            user_id: gives the ID of the user that the token is for
-
-            duration_in_ms: the time that the token will be valid for
-
-            auth_provider_id: the ID of the SSO IdP that the user used to authenticate
-               to get this token, if any. This is encoded in the token so that
-               /login can report stats on number of successful logins by IdP.
-        """
-        logger.warn(
-            "A module configured on this server uses ModuleApi.generate_short_term_login_token(), "
-            "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in "
-            "Synapse 1.71.0",
-        )
-        return self._hs.get_macaroon_generator().generate_short_term_login_token(
-            user_id,
             auth_provider_id,
             auth_provider_session_id,
-            duration_in_ms,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index f554586ac3..7774f1967d 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet):
             The body of the JSON response.
         """
         token = login_submission["token"]
-        auth_handler = self.auth_handler
-        res = await auth_handler.validate_short_term_login_token(token)
+        res = await self.auth_handler.consume_login_token(token)
 
         return await self._complete_login(
             res.user_id,
diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py
index 277b20fb63..43ea21d5e6 100644
--- a/synapse/rest/client/login_token_request.py
+++ b/synapse/rest/client/login_token_request.py
@@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet):
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.server_name = hs.config.server.server_name
-        self.macaroon_gen = hs.get_macaroon_generator()
         self.auth_handler = hs.get_auth_handler()
         self.token_timeout = hs.config.experimental.msc3882_token_timeout
         self.ui_auth = hs.config.experimental.msc3882_ui_auth
@@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet):
                 can_skip_ui_auth=False,  # Don't allow skipping of UI auth
             )
 
-        login_token = self.macaroon_gen.generate_short_term_login_token(
+        login_token = await self.auth_handler.create_login_token_for_user_id(
             user_id=requester.user.to_string(),
             auth_provider_id="org.matrix.msc3882.login_token_request",
-            duration_in_ms=self.token_timeout,
+            duration_ms=self.token_timeout,
         )
 
         return (
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2996d6bb4d..0255295317 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 import attr
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+    Codes,
+    NotFoundError,
+    StoreError,
+    SynapseError,
+    ThreepidValidationError,
+)
 from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
     because this external id is given to an other user."""
 
 
+class LoginTokenExpired(Exception):
+    """Exception if the login token sent expired"""
+
+
+class LoginTokenReused(Exception):
+    """Exception if the login token sent was already used"""
+
+
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:
     """Result of looking up an access token.
@@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
     If None, the session can be refreshed indefinitely."""
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class LoginTokenLookupResult:
+    """Result of looking up a login token."""
+
+    user_id: str
+    """The user this token belongs to."""
+
+    auth_provider_id: Optional[str]
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id: Optional[str]
+    """The session ID advertised by the SSO Identity Provider."""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "replace_refresh_token", _replace_refresh_token_txn
         )
 
+    async def add_login_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        expiry_ts: int,
+        auth_provider_id: Optional[str],
+        auth_provider_session_id: Optional[str],
+    ) -> None:
+        """Adds a short-term login token for the given user.
+
+        Args:
+            user_id: The user ID.
+            token: The new login token to add.
+            expiry_ts (milliseconds since the epoch): Time after which the login token
+                cannot be used.
+            auth_provider_id: The SSO Identity Provider that the user authenticated with
+                to get this token, if any
+            auth_provider_session_id: The session ID advertised by the SSO Identity
+                Provider, if any.
+        """
+        await self.db_pool.simple_insert(
+            "login_tokens",
+            {
+                "token": token,
+                "user_id": user_id,
+                "expiry_ts": expiry_ts,
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            desc="add_login_token_to_user",
+        )
+
+    def _consume_login_token(
+        self,
+        txn: LoggingTransaction,
+        token: str,
+        ts: int,
+    ) -> LoginTokenLookupResult:
+        values = self.db_pool.simple_select_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            retcols=(
+                "user_id",
+                "expiry_ts",
+                "used_ts",
+                "auth_provider_id",
+                "auth_provider_session_id",
+            ),
+            allow_none=True,
+        )
+
+        if values is None:
+            raise NotFoundError()
+
+        self.db_pool.simple_update_one_txn(
+            txn,
+            "login_tokens",
+            keyvalues={"token": token},
+            updatevalues={"used_ts": ts},
+        )
+        user_id = values["user_id"]
+        expiry_ts = values["expiry_ts"]
+        used_ts = values["used_ts"]
+        auth_provider_id = values["auth_provider_id"]
+        auth_provider_session_id = values["auth_provider_session_id"]
+
+        # Token was already used
+        if used_ts is not None:
+            raise LoginTokenReused()
+
+        # Token expired
+        if ts > int(expiry_ts):
+            raise LoginTokenExpired()
+
+        return LoginTokenLookupResult(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+    async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
+        """Lookup a login token and consume it.
+
+        Args:
+            token: The login token.
+
+        Returns:
+            The data stored with that token, including the `user_id`. Returns `None` if
+            the token does not exist or if it expired.
+
+        Raises:
+            NotFound if the login token was not found in database
+            LoginTokenExpired if the login token expired
+            LoginTokenReused if the login token was already used
+        """
+        return await self.db_pool.runInteraction(
+            "consume_login_token",
+            self._consume_login_token,
+            token,
+            self._clock.time_msec(),
+        )
+
     @cached()
     async def is_guest(self, user_id: str) -> bool:
         res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             and hs.config.experimental.msc3866.require_approval_for_new_accounts
         )
 
+        # Create a background job for removing expired login tokens
+        if hs.config.worker.run_background_tasks:
+            self._clock.looping_call(
+                self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             approved,
         )
 
+    @wrap_as_background_process("delete_expired_login_tokens")
+    async def _delete_expired_login_tokens(self) -> None:
+        """Remove login tokens with expiry dates that have passed."""
+
+        def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
+            txn.execute(sql, (ts,))
+
+        # We keep the expired tokens for an extra 5 minutes so we can measure how many
+        # times a token is being used after its expiry
+        now = self._clock.time_msec()
+        await self.db_pool.runInteraction(
+            "delete_expired_login_tokens",
+            _delete_expired_login_tokens_txn,
+            now - (5 * 60 * 1000),
+        )
+
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql
new file mode 100644
index 0000000000..a39b7bcece
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2022 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Login tokens are short-lived tokens that are used for the m.login.token
+-- login method, mainly during SSO logins
+CREATE TABLE login_tokens (
+    token TEXT PRIMARY KEY,
+    user_id TEXT NOT NULL, 
+    expiry_ts BIGINT NOT NULL,
+    used_ts BIGINT,
+    auth_provider_id TEXT,
+    auth_provider_session_id TEXT
+);
+
+-- We're sometimes querying them by their session ID we got from their IDP
+CREATE INDEX login_tokens_auth_provider_idx 
+    ON login_tokens (auth_provider_id, auth_provider_session_id);
+
+-- We're deleting them by their expiration time
+CREATE INDEX login_tokens_expiry_time_idx 
+    ON login_tokens (expiry_ts);
+
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index df77edcce2..5df03d3ddc 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -24,7 +24,7 @@ from typing_extensions import Literal
 
 from synapse.util import Clock, stringutils
 
-MacaroonType = Literal["access", "delete_pusher", "session", "login"]
+MacaroonType = Literal["access", "delete_pusher", "session"]
 
 
 def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -111,19 +111,6 @@ class OidcSessionData:
     """The session ID of the ongoing UI Auth ("" if this is a login)"""
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
-    """Data we store in a short-term login token"""
-
-    user_id: str
-
-    auth_provider_id: str
-    """The SSO Identity Provider that the user authenticated with, to get this token."""
-
-    auth_provider_session_id: Optional[str]
-    """The session ID advertised by the SSO Identity Provider."""
-
-
 class MacaroonGenerator:
     def __init__(self, clock: Clock, location: str, secret_key: bytes):
         self._clock = clock
@@ -165,35 +152,6 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
         return macaroon.serialize()
 
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        auth_provider_id: str,
-        auth_provider_session_id: Optional[str] = None,
-        duration_in_ms: int = (2 * 60 * 1000),
-    ) -> str:
-        """Generate a short-term login token used during SSO logins
-
-        Args:
-            user_id: The user for which the token is valid.
-            auth_provider_id: The SSO IdP the user used.
-            auth_provider_session_id: The session ID got during login from the SSO IdP.
-
-        Returns:
-            A signed token valid for using as a ``m.login.token`` token.
-        """
-        now = self._clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon = self._generate_base_macaroon("login")
-        macaroon.add_first_party_caveat(f"user_id = {user_id}")
-        macaroon.add_first_party_caveat(f"time < {expiry}")
-        macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
-        if auth_provider_session_id is not None:
-            macaroon.add_first_party_caveat(
-                f"auth_provider_session_id = {auth_provider_session_id}"
-            )
-        return macaroon.serialize()
-
     def generate_oidc_session_token(
         self,
         state: str,
@@ -233,49 +191,6 @@ class MacaroonGenerator:
 
         return macaroon.serialize()
 
-    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
-        """Verify a short-term-login macaroon
-
-        Checks that the given token is a valid, unexpired short-term-login token
-        minted by this server.
-
-        Args:
-            token: The login token to verify.
-
-        Returns:
-            A set of attributes carried by this token, including the
-            ``user_id`` and informations about the SSO IDP used during that
-            login.
-
-        Raises:
-            MacaroonVerificationFailedException if the verification failed
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(token)
-
-        v = self._base_verifier("login")
-        v.satisfy_general(lambda c: c.startswith("user_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
-        satisfy_expiry(v, self._clock.time_msec)
-        v.verify(macaroon, self._secret_key)
-
-        user_id = get_value_from_macaroon(macaroon, "user_id")
-        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
-
-        auth_provider_session_id: Optional[str] = None
-        try:
-            auth_provider_session_id = get_value_from_macaroon(
-                macaroon, "auth_provider_session_id"
-            )
-        except MacaroonVerificationFailedException:
-            pass
-
-        return LoginTokenAttributes(
-            user_id=user_id,
-            auth_provider_id=auth_provider_id,
-            auth_provider_session_id=auth_provider_session_id,
-        )
-
     def verify_guest_token(self, token: str) -> str:
         """Verify a guest access token macaroon
 
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
 from unittest.mock import Mock
 
 import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import AuthError, ResourceLimitError
 from synapse.rest import admin
+from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
 class AuthTestCase(unittest.HomeserverTestCase):
     servlets = [
         admin.register_servlets,
+        login.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.user1 = self.register_user("a_user", "pass")
 
+    def token_login(self, token: str) -> Optional[str]:
+        body = {
+            "type": "m.login.token",
+            "token": token,
+        }
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/login",
+            body,
+        )
+
+        if channel.code == 200:
+            return channel.json_body["user_id"]
+
+        return None
+
     def test_macaroon_caveats(self) -> None:
         token = self.macaroon_generator.generate_guest_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.satisfy_general(verify_guest)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-    def test_short_term_login_token_gives_user_id(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+    def test_login_token_gives_user_id(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
-        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+        res = self.get_success(self.auth_handler.consume_login_token(token))
         self.assertEqual(self.user1, res.user_id)
-        self.assertEqual("", res.auth_provider_id)
+        self.assertEqual(None, res.auth_provider_id)
 
-        # when we advance the clock, the token should be rejected
-        self.reactor.advance(6)
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(token),
-            AuthError,
+    def test_login_token_reuse_fails(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
 
-    def test_short_term_login_token_gives_auth_provider(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, auth_provider_id="my_idp"
-        )
-        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
-        self.assertEqual(self.user1, res.user_id)
-        self.assertEqual("my_idp", res.auth_provider_id)
+        self.get_success(self.auth_handler.consume_login_token(token))
 
-    def test_short_term_login_token_cannot_replace_user_id(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+        self.get_failure(
+            self.auth_handler.consume_login_token(token),
+            AuthError,
         )
-        macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        res = self.get_success(
-            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+    def test_login_token_expires(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
-        self.assertEqual(self.user1, res.user_id)
-
-        # add another "user_id" caveat, which might allow us to override the
-        # user_id.
-        macaroon.add_first_party_caveat("user_id = b_user")
 
+        # when we advance the clock, the token should be rejected
+        self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+            self.auth_handler.consume_login_token(token),
             AuthError,
         )
 
+    def test_login_token_gives_auth_provider(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                auth_provider_id="my_idp",
+                auth_provider_session_id="11-22-33-44",
+                duration_ms=(5 * 1000),
+            )
+        )
+        res = self.get_success(self.auth_handler.consume_login_token(token))
+        self.assertEqual(self.user1, res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+        self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
     def test_mau_limits_disabled(self) -> None:
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
 
+        self.assertIsNotNone(self.token_login(token))
+
     def test_mau_limits_exceeded_large(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            ),
-            ResourceLimitError,
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNone(self.token_login(token))
 
     def test_mau_limits_parity(self) -> None:
         # Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ),
             ResourceLimitError,
         )
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            ),
-            ResourceLimitError,
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNone(self.token_login(token))
 
         # If in monthly active cohort
         self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNotNone(self.token_login(token))
 
     def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
-        )
-
-    def _get_macaroon(self) -> pymacaroons.Macaroon:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
-        return pymacaroons.Macaroon.deserialize(token)
+        self.assertIsNotNone(self.token_login(token))
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 32125f7bb7..40754a4711 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
         )
         self.assertEqual(user_id, "@user:tesths")
 
-    def test_short_term_login_token(self):
-        """Test the generation and verification of short-term login tokens"""
-        token = self.macaroon_generator.generate_short_term_login_token(
-            user_id="@user:tesths",
-            auth_provider_id="oidc",
-            auth_provider_session_id="sid",
-            duration_in_ms=2 * 60 * 1000,
-        )
-
-        info = self.macaroon_generator.verify_short_term_login_token(token)
-        self.assertEqual(info.user_id, "@user:tesths")
-        self.assertEqual(info.auth_provider_id, "oidc")
-        self.assertEqual(info.auth_provider_session_id, "sid")
-
-        # Raises with another secret key
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.other_macaroon_generator.verify_short_term_login_token(token)
-
-        # Wait a minute
-        self.reactor.pump([60])
-        # Shouldn't raise
-        self.macaroon_generator.verify_short_term_login_token(token)
-        # Wait another minute
-        self.reactor.pump([60])
-        # Should raise since it expired
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.macaroon_generator.verify_short_term_login_token(token)
-
     def test_oidc_session_token(self):
         """Test the generation and verification of OIDC session cookies"""
         state = "arandomstate"