summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9510.feature1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/api/auth.py41
-rw-r--r--synapse/handlers/auth.py68
-rw-r--r--synapse/handlers/oidc_handler.py65
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/module_api/__init__.py31
-rw-r--r--synapse/rest/client/v1/login.py6
-rw-r--r--synapse/util/macaroons.py89
-rw-r--r--tests/handlers/test_auth.py49
-rw-r--r--tests/handlers/test_cas.py10
-rw-r--r--tests/handlers/test_oidc.py36
-rw-r--r--tests/handlers/test_saml.py10
13 files changed, 258 insertions, 151 deletions
diff --git a/changelog.d/9510.feature b/changelog.d/9510.feature
new file mode 100644
index 0000000000..5214b50d41
--- /dev/null
+++ b/changelog.d/9510.feature
@@ -0,0 +1 @@
+Add prometheus metrics for number of users successfully registering and logging in.
diff --git a/mypy.ini b/mypy.ini
index 64ed45dac2..f31cd432e6 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -69,6 +69,7 @@ files =
   synapse/util/async_helpers.py,
   synapse/util/caches,
   synapse/util/metrics.py,
+  synapse/util/macaroons.py,
   synapse/util/stringutils.py,
   tests/replication,
   tests/test_utils,
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -408,7 +409,7 @@ class Auth:
             raise _InvalidMacaroonException()
 
         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
+            user_id = get_value_from_macaroon(macaroon, "user_id")
 
             guest = False
             for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
                     guest = True
 
             self.validate_macaroon(macaroon, rights, user_id=user_id)
-        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+        except (
+            pymacaroons.exceptions.MacaroonException,
+            KeyError,
+            TypeError,
+            ValueError,
+        ):
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
         if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
 
         return user_id, guest
 
-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-
-        Does *not* validate the macaroon.
-
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-        Returns:
-            (str) user id
-
-        Raises:
-            InvalidClientCredentialsError if there is no user_id caveat in the
-                macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix) :]
-        raise InvalidClientTokenError("No user caveat in macaroon")
-
     def validate_macaroon(self, macaroon, type_string, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
         v.satisfy_exact("type = " + type_string)
         v.satisfy_exact("user_id = %s" % user_id)
         v.satisfy_exact("guest = true")
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self.clock.time_msec)
 
         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-    def _verify_expiry(self, caveat):
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.hs.get_clock().time_msec()
-        return now < expiry
-
     def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)
 
 
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id
 
-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res
 
     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
 
         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
         return macaroon.serialize()
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()
 
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except (MacaroonDeserializationException, ValueError) as e:
+        except (MacaroonDeserializationException, KeyError) as e:
             logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -745,7 +746,7 @@ class OidcProvider:
                 idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
-                ui_auth_session_id=ui_auth_session_id,
+                ui_auth_session_id=ui_auth_session_id or "",
             ),
         )
 
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if session_data.ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-            )
+        macaroon.add_first_party_caveat(
+            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+        )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
             The data extracted from the session cookie
 
         Raises:
-            ValueError if an expected caveat is missing from the macaroon.
+            KeyError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
         v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self._clock.time_msec)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
         # Extract the session data from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
         return OidcSessionData(
             nonce=nonce,
             idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
             ui_auth_session_id=ui_auth_session_id,
         )
 
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            ValueError: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
-        return now < expiry
-
 
 @attr.s(frozen=True, slots=True)
 class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
     # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
     client_redirect_url = attr.ib(type=str)
 
-    # The session ID of the ongoing UI Auth (None if this is a login)
-    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+    # The session ID of the ongoing UI Auth ("" if this is a login)
+    ui_auth_session_id = attr.ib(type=str)
 
 
 UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..8a22dab54a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_login_attributes,
@@ -886,6 +887,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            session.auth_provider_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
         )
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
+        auth_provider_id: str = "",
     ) -> str:
-        """Generate a login token suitable for m.login.token authentication"""
+        """Generate a login token suitable for m.login.token authentication
+
+        Args:
+            user_id: gives the ID of the user that the token is for
+
+            duration_in_ms: the time that the token will be valid for
+
+            auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+               to get this token, if any. This is encoded in the token so that
+               /login can report stats on number of successful logins by IdP.
+        """
         return self._hs.get_macaroon_generator().generate_short_term_login_token(
-            user_id, duration_in_ms
+            user_id,
+            auth_provider_id,
+            duration_in_ms,
         )
 
     @defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
         """
         self._auth_handler._complete_sso_login(
             registered_user_id,
+            "<unknown>",
             request,
             client_redirect_url,
         )
@@ -286,6 +302,7 @@ class ModuleApi:
         request: SynapseRequest,
         client_redirect_url: str,
         new_user: bool = False,
+        auth_provider_id: str = "<unknown>",
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
                 redirect them directly if whitelisted).
             new_user: set to true to use wording for the consent appropriate to a user
                 who has just registered.
+            auth_provider_id: the ID of the SSO IdP which was used to log in. This
+                is used to track counts of sucessful logins by IdP.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url, new_user=new_user
+            registered_user_id,
+            auth_provider_id,
+            request,
+            client_redirect_url,
+            new_user=new_user,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..1ec3a47ffb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -283,12 +283,10 @@ class LoginRestServlet(RestServlet):
         """
         token = login_submission["token"]
         auth_handler = self.auth_handler
-        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
-            token
-        )
+        res = await auth_handler.validate_short_term_login_token(token)
 
         return await self._complete_login(
-            user_id, login_submission, self.auth_handler._sso_login_callback
+            res.user_id, login_submission, self.auth_handler._sso_login_callback
         )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+    """Extracts a caveat value from a macaroon token.
+
+    Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+    and returns the extracted value.
+
+    Args:
+        macaroon: the token
+        key: the key of the caveat to extract
+
+    Returns:
+        The extracted value
+
+    Raises:
+        MacaroonVerificationFailedException: if there are conflicting values for the
+             caveat in the macaroon, or if the caveat was not found in the macaroon.
+    """
+    prefix = key + " = "
+    result = None  # type: Optional[str]
+    for caveat in macaroon.caveats:
+        if not caveat.caveat_id.startswith(prefix):
+            continue
+
+        val = caveat.caveat_id[len(prefix) :]
+
+        if result is None:
+            # first time we found this caveat: record the value
+            result = val
+        elif val != result:
+            # on subsequent occurrences, raise if the value is different.
+            raise MacaroonVerificationFailedException(
+                "Conflicting values for caveat " + key
+            )
+
+    if result is not None:
+        return result
+
+    # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+    # Note that it is insecure to generate a macaroon without all the caveats you
+    # might need (because there is nothing stopping people from adding extra caveats),
+    # so if the caveat isn't there, something odd must be going on.
+    raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+    """Make a macaroon verifier which accepts 'time' caveats
+
+    Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+    the given macaroon verifier.
+
+    Args:
+        v: the macaroon verifier
+        get_time_ms: a callable which will return the timestamp after which the caveat
+            should be considered expired. Normally the current time.
+    """
+
+    def verify_expiry_caveat(caveat: str):
+        time_msec = get_time_ms()
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        return time_msec < expiry
+
+    v.satisfy_general(verify_expiry_caveat)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
     def test_short_term_login_token_gives_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
         )
-        self.assertEqual("a_user", user_id)
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("", res.auth_provider_id)
 
         # when we advance the clock, the token should be rejected
         self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            self.auth_handler.validate_short_term_login_token(token),
             AuthError,
         )
 
+    def test_short_term_login_token_gives_auth_provider(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", auth_provider_id="my_idp"
+        )
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
+        )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+        res = self.get_success(
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
         )
-        self.assertEqual("a_user", user_id)
+        self.assertEqual("a_user", res.user_id)
 
         # add another "user_id" caveat, which might allow us to override the
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            ),
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
             AuthError,
         )
 
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
 
     def _get_macaroon(self):
-        token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", "", 5000
+        )
         return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     @override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..02d4b2de0d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=True
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=False
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler.complete_sso_login.assert_called_once_with(
             "@foo:test",
+            "oidc",
             request,
             client_redirect_url,
             {"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None, new_user=True
+            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None, new_user=True
+            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None, new_user=True
+            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state: str,
         nonce: str,
         client_redirect_url: str,
-        ui_auth_session_id: Optional[str] = None,
+        ui_auth_session_id: str = "",
     ) -> str:
         from synapse.handlers.oidc_handler import OidcSessionData
 
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
             idp_id="oidc",
             nonce="nonce",
             client_redirect_url=client_redirect_url,
+            ui_auth_session_id="",
         ),
     )
     request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None, new_user=True
+            "@test_user1:test", "saml", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )