summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--README.rst5
-rw-r--r--UPGRADE.rst7
-rw-r--r--changelog.d/9508.doc1
-rw-r--r--changelog.d/9510.feature1
-rw-r--r--changelog.d/9511.feature1
-rw-r--r--changelog.d/9528.misc1
-rw-r--r--changelog.d/9540.feature1
-rw-r--r--changelog.d/9540.removal1
-rw-r--r--changelog.d/9550.doc1
-rw-r--r--docs/reverse_proxy.md51
-rw-r--r--mypy.ini1
-rw-r--r--synapse/api/auth.py41
-rw-r--r--synapse/federation/federation_server.py10
-rw-r--r--synapse/federation/sender/transaction_manager.py11
-rw-r--r--synapse/handlers/acme.py4
-rw-r--r--synapse/handlers/auth.py68
-rw-r--r--synapse/handlers/oidc_handler.py65
-rw-r--r--synapse/handlers/register.py35
-rw-r--r--synapse/handlers/sso.py3
-rw-r--r--synapse/http/client.py5
-rw-r--r--synapse/http/federation/matrix_federation_agent.py3
-rw-r--r--synapse/http/matrixfederationclient.py8
-rw-r--r--synapse/module_api/__init__.py31
-rw-r--r--synapse/replication/tcp/redis.py2
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/types.py16
-rw-r--r--synapse/util/macaroons.py89
-rw-r--r--tests/handlers/test_auth.py49
-rw-r--r--tests/handlers/test_cas.py10
-rw-r--r--tests/handlers/test_oidc.py36
-rw-r--r--tests/handlers/test_saml.py10
32 files changed, 404 insertions, 182 deletions
diff --git a/README.rst b/README.rst
index d872b11f57..6a1e713590 100644
--- a/README.rst
+++ b/README.rst
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
 It is recommended to put a reverse proxy such as
 `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
 `Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
-`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
-`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
+`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
+`HAProxy <https://www.haproxy.org/>`_ or
+`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
 doing so is that it means that you can expose the default https port (443) to
 Matrix clients without needing to run Synapse with root privileges.
 
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 031e02bda9..8bc2ff91ab 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
   need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
   "ACS location" (also known as "allowed callback URLs") at the identity provider.
 
+  The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
+  ``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
+  provider uses this property to validate or otherwise identify Synapse, its configuration
+  will need to be updated to use the new URL. Alternatively you could create a new, separate
+  "EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
+  the existing "EntityDescriptor" as they were.
+
 Changes to HTML templates
 -------------------------
 
diff --git a/changelog.d/9508.doc b/changelog.d/9508.doc
new file mode 100644
index 0000000000..a17a8faecf
--- /dev/null
+++ b/changelog.d/9508.doc
@@ -0,0 +1 @@
+Add relayd entry to reverse proxy example configurations.
diff --git a/changelog.d/9510.feature b/changelog.d/9510.feature
new file mode 100644
index 0000000000..5214b50d41
--- /dev/null
+++ b/changelog.d/9510.feature
@@ -0,0 +1 @@
+Add prometheus metrics for number of users successfully registering and logging in.
diff --git a/changelog.d/9511.feature b/changelog.d/9511.feature
new file mode 100644
index 0000000000..5214b50d41
--- /dev/null
+++ b/changelog.d/9511.feature
@@ -0,0 +1 @@
+Add prometheus metrics for number of users successfully registering and logging in.
diff --git a/changelog.d/9528.misc b/changelog.d/9528.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9528.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9540.feature b/changelog.d/9540.feature
new file mode 100644
index 0000000000..5417e51b93
--- /dev/null
+++ b/changelog.d/9540.feature
@@ -0,0 +1 @@
+Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers.
diff --git a/changelog.d/9540.removal b/changelog.d/9540.removal
new file mode 100644
index 0000000000..d54f553cb9
--- /dev/null
+++ b/changelog.d/9540.removal
@@ -0,0 +1 @@
+The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`.
diff --git a/changelog.d/9550.doc b/changelog.d/9550.doc
new file mode 100644
index 0000000000..adbbeb0ae4
--- /dev/null
+++ b/changelog.d/9550.doc
@@ -0,0 +1 @@
+Improve the SAML2 upgrade notes for 1.27.0.
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 81e5a68a36..860afd5a04 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -3,8 +3,9 @@
 It is recommended to put a reverse proxy such as
 [nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
 [Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
-[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
-[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
+[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
+[HAProxy](https://www.haproxy.org/) or
+[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
 of doing so is that it means that you can expose the default https port
 (443) to Matrix clients without needing to run Synapse with root
 privileges.
@@ -162,6 +163,52 @@ backend matrix
   server matrix 127.0.0.1:8008
 ```
 
+### Relayd
+
+```
+table <webserver>    { 127.0.0.1 }
+table <matrixserver> { 127.0.0.1 }
+
+http protocol "https" {
+    tls { no tlsv1.0, ciphers "HIGH" }
+    tls keypair "example.com"
+    match header set "X-Forwarded-For"   value "$REMOTE_ADDR"
+    match header set "X-Forwarded-Proto" value "https"
+
+    # set CORS header for .well-known/matrix/server, .well-known/matrix/client
+    # httpd does not support setting headers, so do it here
+    match request path "/.well-known/matrix/*" tag "matrix-cors"
+    match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
+
+    pass quick path "/_matrix/*"         forward to <matrixserver>
+    pass quick path "/_synapse/client/*" forward to <matrixserver>
+
+    # pass on non-matrix traffic to webserver
+    pass                                 forward to <webserver>
+}
+
+relay "https_traffic" {
+    listen on egress port 443 tls
+    protocol "https"
+    forward to <matrixserver> port 8008 check tcp
+    forward to <webserver>    port 8080 check tcp
+}
+
+http protocol "matrix" {
+    tls { no tlsv1.0, ciphers "HIGH" }
+    tls keypair "example.com"
+    block
+    pass quick path "/_matrix/*"         forward to <matrixserver>
+    pass quick path "/_synapse/client/*" forward to <matrixserver>
+}
+
+relay "matrix_federation" {
+    listen on egress port 8448 tls
+    protocol "matrix"
+    forward to <matrixserver> port 8008 check tcp
+}
+```
+
 ## Homeserver Configuration
 
 You will also want to set `bind_addresses: ['127.0.0.1']` and
diff --git a/mypy.ini b/mypy.ini
index 64ed45dac2..f31cd432e6 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -69,6 +69,7 @@ files =
   synapse/util/async_helpers.py,
   synapse/util/caches,
   synapse/util/metrics.py,
+  synapse/util/macaroons.py,
   synapse/util/stringutils.py,
   tests/replication,
   tests/test_utils,
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -408,7 +409,7 @@ class Auth:
             raise _InvalidMacaroonException()
 
         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
+            user_id = get_value_from_macaroon(macaroon, "user_id")
 
             guest = False
             for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
                     guest = True
 
             self.validate_macaroon(macaroon, rights, user_id=user_id)
-        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+        except (
+            pymacaroons.exceptions.MacaroonException,
+            KeyError,
+            TypeError,
+            ValueError,
+        ):
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
         if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
 
         return user_id, guest
 
-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-
-        Does *not* validate the macaroon.
-
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-        Returns:
-            (str) user id
-
-        Raises:
-            InvalidClientCredentialsError if there is no user_id caveat in the
-                macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix) :]
-        raise InvalidClientTokenError("No user caveat in macaroon")
-
     def validate_macaroon(self, macaroon, type_string, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
         v.satisfy_exact("type = " + type_string)
         v.satisfy_exact("user_id = %s" % user_id)
         v.satisfy_exact("guest = true")
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self.clock.time_msec)
 
         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-    def _verify_expiry(self, caveat):
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.hs.get_clock().time_msec()
-        return now < expiry
-
     def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2f832b47f6..362895bf42 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -90,10 +90,9 @@ pdu_process_time = Histogram(
     "Time taken to process an event",
 )
 
-
-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_received_pdu_age",
-    "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_received_pdu_time",
+    "The timestamp of the last PDU which was successfully received from the given domain",
     labelnames=("server_name",),
 )
 
@@ -369,8 +368,7 @@ class FederationServer(FederationBase):
         )
 
         if newest_pdu_ts and origin in self._federation_metrics_domains:
-            newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
-            last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+            last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
 
         return pdu_results
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff296c..2a9cd063c4 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_sent_pdu_age",
-    "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_sent_pdu_time",
+    "The timestamp of the last PDU which was successfully sent to the given domain",
     labelnames=("server_name",),
 )
 
@@ -187,9 +187,8 @@ class TransactionManager:
 
             if success and pdus and destination in self._federation_metrics_domains:
                 last_pdu = pdus[-1]
-                last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
-                last_pdu_age_metric.labels(server_name=destination).set(
-                    last_pdu_age / 1000
+                last_pdu_ts_metric.labels(server_name=destination).set(
+                    last_pdu.origin_server_ts / 1000
                 )
 
             set_tag(tags.ERROR, not success)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1ac..132be238dd 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler:
                 "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
             )
             try:
-                self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+                self.reactor.listenTCP(
+                    self.hs.config.acme_port, srv, backlog=50, interface=host
+                )
             except twisted.internet.error.CannotListenError as e:
                 check_bind_error(e, host, bind_addresses)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)
 
 
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id
 
-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res
 
     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
 
         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
         return macaroon.serialize()
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()
 
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except (MacaroonDeserializationException, ValueError) as e:
+        except (MacaroonDeserializationException, KeyError) as e:
             logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -745,7 +746,7 @@ class OidcProvider:
                 idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
-                ui_auth_session_id=ui_auth_session_id,
+                ui_auth_session_id=ui_auth_session_id or "",
             ),
         )
 
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if session_data.ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-            )
+        macaroon.add_first_party_caveat(
+            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+        )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
             The data extracted from the session cookie
 
         Raises:
-            ValueError if an expected caveat is missing from the macaroon.
+            KeyError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
         v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self._clock.time_msec)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
         # Extract the session data from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
         return OidcSessionData(
             nonce=nonce,
             idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
             ui_auth_session_id=ui_auth_session_id,
         )
 
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            ValueError: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
-        return now < expiry
-
 
 @attr.s(frozen=True, slots=True)
 class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
     # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
     client_redirect_url = attr.ib(type=str)
 
-    # The session ID of the ongoing UI Auth (None if this is a login)
-    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+    # The session ID of the ongoing UI Auth ("" if this is a login)
+    ui_auth_session_id = attr.ib(type=str)
 
 
 UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3cda89657e..b66f8756b8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -18,6 +18,8 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
+from prometheus_client import Counter
+
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+registration_counter = Counter(
+    "synapse_user_registrations_total",
+    "Number of new users registered (since restart)",
+    ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+    "synapse_user_logins_total",
+    "Number of user logins (since restart)",
+    ["guest", "auth_provider"],
+)
+
+
 class RegistrationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
@@ -156,6 +171,7 @@ class RegistrationHandler(BaseHandler):
         bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+        auth_provider_id: Optional[str] = None,
     ) -> str:
         """Registers a new client on the server.
 
@@ -181,8 +197,10 @@ class RegistrationHandler(BaseHandler):
               admin api, otherwise False.
             user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
         Returns:
-            The registere user_id.
+            The registered user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -280,6 +298,12 @@ class RegistrationHandler(BaseHandler):
                     # if user id is taken, just generate another
                     fail_count += 1
 
+        registration_counter.labels(
+            guest=make_guest,
+            shadow_banned=shadow_banned,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
         if not self.hs.config.user_consent_at_registration:
             if not self.hs.config.auto_join_rooms_for_guests and make_guest:
                 logger.info(
@@ -638,6 +662,7 @@ class RegistrationHandler(BaseHandler):
         initial_display_name: Optional[str],
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
+        auth_provider_id: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
@@ -648,7 +673,8 @@ class RegistrationHandler(BaseHandler):
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
         Returns:
             Tuple of device ID and access token
         """
@@ -687,6 +713,11 @@ class RegistrationHandler(BaseHandler):
                 is_appservice_ghost=is_appservice_ghost,
             )
 
+        login_counter.labels(
+            guest=is_guest,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
         return (registered_device_id, access_token)
 
     async def post_registration_actions(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..6ef459acff 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_login_attributes,
@@ -605,6 +606,7 @@ class SsoHandler:
             default_display_name=attributes.display_name,
             bind_emails=attributes.emails,
             user_agent_ips=[(user_agent, ip_address)],
+            auth_provider_id=auth_provider_id,
         )
 
         await self._store.record_user_external_id(
@@ -886,6 +888,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            session.auth_provider_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f95..af34d583ad 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
         return r
 
 
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
 class BlacklistingReactorWrapper:
     """
     A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +325,7 @@ class SimpleHttpClient:
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
             self.reactor = BlacklistingReactorWrapper(
                 hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
-            )
+            )  # type: ISynapseReactor
         else:
             self.reactor = hs.get_reactor()
 
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c08..5935a125fd 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
 
     def __init__(
         self,
-        reactor: IReactorCore,
+        reactor: ISynapseReactor,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
         ip_blacklist: IPSet,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714ea..da6866addf 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
     start_active_span,
     tags,
 )
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
         # addresses, to prevent DNS rebinding.
         self.reactor = BlacklistingReactorWrapper(
             hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
-        )
+        )  # type: ISynapseReactor
 
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
         user_agent = user_agent.encode("ascii")
 
-        self.agent = MatrixFederationAgent(
+        federation_agent = MatrixFederationAgent(
             self.reactor,
             tls_client_options_factory,
             user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names
         self.agent = BlacklistingAgentWrapper(
-            self.agent,
+            federation_agent,
             ip_blacklist=hs.config.federation_ip_range_blacklist,
         )
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
         )
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
+        auth_provider_id: str = "",
     ) -> str:
-        """Generate a login token suitable for m.login.token authentication"""
+        """Generate a login token suitable for m.login.token authentication
+
+        Args:
+            user_id: gives the ID of the user that the token is for
+
+            duration_in_ms: the time that the token will be valid for
+
+            auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+               to get this token, if any. This is encoded in the token so that
+               /login can report stats on number of successful logins by IdP.
+        """
         return self._hs.get_macaroon_generator().generate_short_term_login_token(
-            user_id, duration_in_ms
+            user_id,
+            auth_provider_id,
+            duration_in_ms,
         )
 
     @defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
         """
         self._auth_handler._complete_sso_login(
             registered_user_id,
+            "<unknown>",
             request,
             client_redirect_url,
         )
@@ -286,6 +302,7 @@ class ModuleApi:
         request: SynapseRequest,
         client_redirect_url: str,
         new_user: bool = False,
+        auth_provider_id: str = "<unknown>",
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
                 redirect them directly if whitelisted).
             new_user: set to true to use wording for the consent appropriate to a user
                 who has just registered.
+            auth_provider_id: the ID of the SSO IdP which was used to log in. This
+                is used to track counts of sucessful logins by IdP.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url, new_user=new_user
+            registered_user_id,
+            auth_provider_id,
+            request,
+            client_redirect_url,
+            new_user=new_user,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf53..7560706b4b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -328,6 +328,6 @@ def lazyConnection(
     factory.continueTrying = reconnect
 
     reactor = hs.get_reactor()
-    reactor.connectTCP(host, port, factory, 30)
+    reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
 
     return factory.handler
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
+        auth_provider_id: Optional[str] = None,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name
+            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
         )
 
         result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
         """
         token = login_submission["token"]
         auth_handler = self.auth_handler
-        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
-            token
-        )
+        res = await auth_handler.validate_short_term_login_token(token)
 
         return await self._complete_login(
-            user_id, login_submission, self.auth_handler._sso_login_callback
+            res.user_id,
+            login_submission,
+            self.auth_handler._sso_login_callback,
+            auth_provider_id=res.auth_provider_id,
         )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72e7..369cc88026 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import (
     cast,
 )
 
-import twisted.internet.base
 import twisted.internet.tcp
 from twisted.internet import defer
 from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
 from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import Databases, DataStore, Storage
 from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
             getattr(self, "get_" + i + "_handler")()
 
-    def get_reactor(self) -> twisted.internet.base.ReactorBase:
+    def get_reactor(self) -> ISynapseReactor:
         """
         Fetch the Twisted reactor in use by this HomeServer.
         """
diff --git a/synapse/types.py b/synapse/types.py
index 721343f0b5..0216d213c7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -35,6 +35,14 @@ from typing import (
 import attr
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+    IReactorCore,
+    IReactorPluggableNameResolver,
+    IReactorTCP,
+    IReactorTime,
+)
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.stringutils import parse_and_validate_server_name
@@ -67,6 +75,14 @@ MutableStateMap = MutableMapping[StateKey, T]
 JsonDict = Dict[str, Any]
 
 
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+    IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+):
+    """The interfaces necessary for Synapse to function."""
+
+
 class Requester(
     namedtuple(
         "Requester",
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+    """Extracts a caveat value from a macaroon token.
+
+    Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+    and returns the extracted value.
+
+    Args:
+        macaroon: the token
+        key: the key of the caveat to extract
+
+    Returns:
+        The extracted value
+
+    Raises:
+        MacaroonVerificationFailedException: if there are conflicting values for the
+             caveat in the macaroon, or if the caveat was not found in the macaroon.
+    """
+    prefix = key + " = "
+    result = None  # type: Optional[str]
+    for caveat in macaroon.caveats:
+        if not caveat.caveat_id.startswith(prefix):
+            continue
+
+        val = caveat.caveat_id[len(prefix) :]
+
+        if result is None:
+            # first time we found this caveat: record the value
+            result = val
+        elif val != result:
+            # on subsequent occurrences, raise if the value is different.
+            raise MacaroonVerificationFailedException(
+                "Conflicting values for caveat " + key
+            )
+
+    if result is not None:
+        return result
+
+    # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+    # Note that it is insecure to generate a macaroon without all the caveats you
+    # might need (because there is nothing stopping people from adding extra caveats),
+    # so if the caveat isn't there, something odd must be going on.
+    raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+    """Make a macaroon verifier which accepts 'time' caveats
+
+    Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+    the given macaroon verifier.
+
+    Args:
+        v: the macaroon verifier
+        get_time_ms: a callable which will return the timestamp after which the caveat
+            should be considered expired. Normally the current time.
+    """
+
+    def verify_expiry_caveat(caveat: str):
+        time_msec = get_time_ms()
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        return time_msec < expiry
+
+    v.satisfy_general(verify_expiry_caveat)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
     def test_short_term_login_token_gives_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
         )
-        self.assertEqual("a_user", user_id)
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("", res.auth_provider_id)
 
         # when we advance the clock, the token should be rejected
         self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            self.auth_handler.validate_short_term_login_token(token),
             AuthError,
         )
 
+    def test_short_term_login_token_gives_auth_provider(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", auth_provider_id="my_idp"
+        )
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
+        )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+        res = self.get_success(
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
         )
-        self.assertEqual("a_user", user_id)
+        self.assertEqual("a_user", res.user_id)
 
         # add another "user_id" caveat, which might allow us to override the
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            ),
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
             AuthError,
         )
 
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
 
     def _get_macaroon(self):
-        token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", "", 5000
+        )
         return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     @override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..02d4b2de0d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=True
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=False
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler.complete_sso_login.assert_called_once_with(
             "@foo:test",
+            "oidc",
             request,
             client_redirect_url,
             {"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None, new_user=True
+            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None, new_user=True
+            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None, new_user=True
+            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state: str,
         nonce: str,
         client_redirect_url: str,
-        ui_auth_session_id: Optional[str] = None,
+        ui_auth_session_id: str = "",
     ) -> str:
         from synapse.handlers.oidc_handler import OidcSessionData
 
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
             idp_id="oidc",
             nonce="nonce",
             client_redirect_url=client_redirect_url,
+            ui_auth_session_id="",
         ),
     )
     request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None, new_user=True
+            "@test_user1:test", "saml", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )