summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py41
-rw-r--r--synapse/federation/federation_server.py10
-rw-r--r--synapse/federation/sender/transaction_manager.py11
-rw-r--r--synapse/handlers/acme.py4
-rw-r--r--synapse/handlers/auth.py68
-rw-r--r--synapse/handlers/oidc_handler.py65
-rw-r--r--synapse/handlers/register.py35
-rw-r--r--synapse/handlers/sso.py3
-rw-r--r--synapse/http/client.py5
-rw-r--r--synapse/http/federation/matrix_federation_agent.py3
-rw-r--r--synapse/http/matrixfederationclient.py8
-rw-r--r--synapse/module_api/__init__.py31
-rw-r--r--synapse/replication/tcp/redis.py2
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/types.py16
-rw-r--r--synapse/util/macaroons.py89
17 files changed, 282 insertions, 128 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -408,7 +409,7 @@ class Auth:
             raise _InvalidMacaroonException()
 
         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
+            user_id = get_value_from_macaroon(macaroon, "user_id")
 
             guest = False
             for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
                     guest = True
 
             self.validate_macaroon(macaroon, rights, user_id=user_id)
-        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+        except (
+            pymacaroons.exceptions.MacaroonException,
+            KeyError,
+            TypeError,
+            ValueError,
+        ):
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
         if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
 
         return user_id, guest
 
-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-
-        Does *not* validate the macaroon.
-
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-        Returns:
-            (str) user id
-
-        Raises:
-            InvalidClientCredentialsError if there is no user_id caveat in the
-                macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix) :]
-        raise InvalidClientTokenError("No user caveat in macaroon")
-
     def validate_macaroon(self, macaroon, type_string, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
         v.satisfy_exact("type = " + type_string)
         v.satisfy_exact("user_id = %s" % user_id)
         v.satisfy_exact("guest = true")
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self.clock.time_msec)
 
         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-    def _verify_expiry(self, caveat):
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.hs.get_clock().time_msec()
-        return now < expiry
-
     def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2f832b47f6..362895bf42 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -90,10 +90,9 @@ pdu_process_time = Histogram(
     "Time taken to process an event",
 )
 
-
-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_received_pdu_age",
-    "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_received_pdu_time",
+    "The timestamp of the last PDU which was successfully received from the given domain",
     labelnames=("server_name",),
 )
 
@@ -369,8 +368,7 @@ class FederationServer(FederationBase):
         )
 
         if newest_pdu_ts and origin in self._federation_metrics_domains:
-            newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
-            last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+            last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
 
         return pdu_results
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff296c..2a9cd063c4 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-last_pdu_age_metric = Gauge(
-    "synapse_federation_last_sent_pdu_age",
-    "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+    "synapse_federation_last_sent_pdu_time",
+    "The timestamp of the last PDU which was successfully sent to the given domain",
     labelnames=("server_name",),
 )
 
@@ -187,9 +187,8 @@ class TransactionManager:
 
             if success and pdus and destination in self._federation_metrics_domains:
                 last_pdu = pdus[-1]
-                last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
-                last_pdu_age_metric.labels(server_name=destination).set(
-                    last_pdu_age / 1000
+                last_pdu_ts_metric.labels(server_name=destination).set(
+                    last_pdu.origin_server_ts / 1000
                 )
 
             set_tag(tags.ERROR, not success)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1ac..132be238dd 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler:
                 "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
             )
             try:
-                self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+                self.reactor.listenTCP(
+                    self.hs.config.acme_port, srv, backlog=50, interface=host
+                )
             except twisted.internet.error.CannotListenError as e:
                 check_bind_error(e, host, bind_addresses)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
     extra_attributes = attr.ib(type=JsonDict)
 
 
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id = attr.ib(type=str)
+
+    # the SSO Identity Provider that the user authenticated with, to get this token
+    auth_provider_id = attr.ib(type=str)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
             return None
         return user_id
 
-    async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
-        auth_api = self.hs.get_auth()
-        user_id = None
+    async def validate_short_term_login_token(
+        self, login_token: str
+    ) -> LoginTokenAttributes:
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", user_id)
+            res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(user_id)
-        return user_id
+        await self.auth.check_auth_blocking(res.user_id)
+        return res
 
     async def delete_access_token(self, access_token: str):
         """Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            auth_provider_id: The id of the SSO Identity provider that was used for
+                login. This will be stored in the login token for future tracking in
+                prometheus metrics.
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
 
         self._complete_sso_login(
             registered_user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
+        auth_provider_id: str,
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+            registered_user_id, auth_provider_id=auth_provider_id
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
         return macaroon.serialize()
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
         return macaroon.serialize()
 
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: the login token to verify
+
+        Returns:
+            the user_id that this token is valid for
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        satisfy_expiry(v, self.hs.get_clock().time_msec)
+        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except (MacaroonDeserializationException, ValueError) as e:
+        except (MacaroonDeserializationException, KeyError) as e:
             logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -745,7 +746,7 @@ class OidcProvider:
                 idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
-                ui_auth_session_id=ui_auth_session_id,
+                ui_auth_session_id=ui_auth_session_id or "",
             ),
         )
 
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if session_data.ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-            )
+        macaroon.add_first_party_caveat(
+            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+        )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
             The data extracted from the session cookie
 
         Raises:
-            ValueError if an expected caveat is missing from the macaroon.
+            KeyError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
         v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
+        satisfy_expiry(v, self._clock.time_msec)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
         # Extract the session data from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
         return OidcSessionData(
             nonce=nonce,
             idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
             ui_auth_session_id=ui_auth_session_id,
         )
 
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            ValueError: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
-        return now < expiry
-
 
 @attr.s(frozen=True, slots=True)
 class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
     # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
     client_redirect_url = attr.ib(type=str)
 
-    # The session ID of the ongoing UI Auth (None if this is a login)
-    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+    # The session ID of the ongoing UI Auth ("" if this is a login)
+    ui_auth_session_id = attr.ib(type=str)
 
 
 UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3cda89657e..b66f8756b8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -18,6 +18,8 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
+from prometheus_client import Counter
+
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+registration_counter = Counter(
+    "synapse_user_registrations_total",
+    "Number of new users registered (since restart)",
+    ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+    "synapse_user_logins_total",
+    "Number of user logins (since restart)",
+    ["guest", "auth_provider"],
+)
+
+
 class RegistrationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
@@ -156,6 +171,7 @@ class RegistrationHandler(BaseHandler):
         bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+        auth_provider_id: Optional[str] = None,
     ) -> str:
         """Registers a new client on the server.
 
@@ -181,8 +197,10 @@ class RegistrationHandler(BaseHandler):
               admin api, otherwise False.
             user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
         Returns:
-            The registere user_id.
+            The registered user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -280,6 +298,12 @@ class RegistrationHandler(BaseHandler):
                     # if user id is taken, just generate another
                     fail_count += 1
 
+        registration_counter.labels(
+            guest=make_guest,
+            shadow_banned=shadow_banned,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
         if not self.hs.config.user_consent_at_registration:
             if not self.hs.config.auto_join_rooms_for_guests and make_guest:
                 logger.info(
@@ -638,6 +662,7 @@ class RegistrationHandler(BaseHandler):
         initial_display_name: Optional[str],
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
+        auth_provider_id: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
@@ -648,7 +673,8 @@ class RegistrationHandler(BaseHandler):
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
         Returns:
             Tuple of device ID and access token
         """
@@ -687,6 +713,11 @@ class RegistrationHandler(BaseHandler):
                 is_appservice_ghost=is_appservice_ghost,
             )
 
+        login_counter.labels(
+            guest=is_guest,
+            auth_provider=(auth_provider_id or ""),
+        ).inc()
+
         return (registered_device_id, access_token)
 
     async def post_registration_actions(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..6ef459acff 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            auth_provider_id,
             request,
             client_redirect_url,
             extra_login_attributes,
@@ -605,6 +606,7 @@ class SsoHandler:
             default_display_name=attributes.display_name,
             bind_emails=attributes.emails,
             user_agent_ips=[(user_agent, ip_address)],
+            auth_provider_id=auth_provider_id,
         )
 
         await self._store.record_user_external_id(
@@ -886,6 +888,7 @@ class SsoHandler:
 
         await self._auth_handler.complete_sso_login(
             user_id,
+            session.auth_provider_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f95..af34d583ad 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
         return r
 
 
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
 class BlacklistingReactorWrapper:
     """
     A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +325,7 @@ class SimpleHttpClient:
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
             self.reactor = BlacklistingReactorWrapper(
                 hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
-            )
+            )  # type: ISynapseReactor
         else:
             self.reactor = hs.get_reactor()
 
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c08..5935a125fd 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
 
     def __init__(
         self,
-        reactor: IReactorCore,
+        reactor: ISynapseReactor,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
         ip_blacklist: IPSet,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714ea..da6866addf 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
     start_active_span,
     tags,
 )
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
         # addresses, to prevent DNS rebinding.
         self.reactor = BlacklistingReactorWrapper(
             hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
-        )
+        )  # type: ISynapseReactor
 
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
         user_agent = user_agent.encode("ascii")
 
-        self.agent = MatrixFederationAgent(
+        federation_agent = MatrixFederationAgent(
             self.reactor,
             tls_client_options_factory,
             user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names
         self.agent = BlacklistingAgentWrapper(
-            self.agent,
+            federation_agent,
             ip_blacklist=hs.config.federation_ip_range_blacklist,
         )
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
         )
 
     def generate_short_term_login_token(
-        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+        self,
+        user_id: str,
+        duration_in_ms: int = (2 * 60 * 1000),
+        auth_provider_id: str = "",
     ) -> str:
-        """Generate a login token suitable for m.login.token authentication"""
+        """Generate a login token suitable for m.login.token authentication
+
+        Args:
+            user_id: gives the ID of the user that the token is for
+
+            duration_in_ms: the time that the token will be valid for
+
+            auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+               to get this token, if any. This is encoded in the token so that
+               /login can report stats on number of successful logins by IdP.
+        """
         return self._hs.get_macaroon_generator().generate_short_term_login_token(
-            user_id, duration_in_ms
+            user_id,
+            auth_provider_id,
+            duration_in_ms,
         )
 
     @defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
         """
         self._auth_handler._complete_sso_login(
             registered_user_id,
+            "<unknown>",
             request,
             client_redirect_url,
         )
@@ -286,6 +302,7 @@ class ModuleApi:
         request: SynapseRequest,
         client_redirect_url: str,
         new_user: bool = False,
+        auth_provider_id: str = "<unknown>",
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
                 redirect them directly if whitelisted).
             new_user: set to true to use wording for the consent appropriate to a user
                 who has just registered.
+            auth_provider_id: the ID of the SSO IdP which was used to log in. This
+                is used to track counts of sucessful logins by IdP.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url, new_user=new_user
+            registered_user_id,
+            auth_provider_id,
+            request,
+            client_redirect_url,
+            new_user=new_user,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf53..7560706b4b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -328,6 +328,6 @@ def lazyConnection(
     factory.continueTrying = reconnect
 
     reactor = hs.get_reactor()
-    reactor.connectTCP(host, port, factory, 30)
+    reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
 
     return factory.handler
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
+        auth_provider_id: Optional[str] = None,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
+            auth_provider_id: The SSO IdP the user used, if any (just used for the
+                prometheus metrics).
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name
+            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
         )
 
         result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
         """
         token = login_submission["token"]
         auth_handler = self.auth_handler
-        user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
-            token
-        )
+        res = await auth_handler.validate_short_term_login_token(token)
 
         return await self._complete_login(
-            user_id, login_submission, self.auth_handler._sso_login_callback
+            res.user_id,
+            login_submission,
+            self.auth_handler._sso_login_callback,
+            auth_provider_id=res.auth_provider_id,
         )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72e7..369cc88026 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import (
     cast,
 )
 
-import twisted.internet.base
 import twisted.internet.tcp
 from twisted.internet import defer
 from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
 from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import Databases, DataStore, Storage
 from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
             getattr(self, "get_" + i + "_handler")()
 
-    def get_reactor(self) -> twisted.internet.base.ReactorBase:
+    def get_reactor(self) -> ISynapseReactor:
         """
         Fetch the Twisted reactor in use by this HomeServer.
         """
diff --git a/synapse/types.py b/synapse/types.py
index 721343f0b5..0216d213c7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -35,6 +35,14 @@ from typing import (
 import attr
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+    IReactorCore,
+    IReactorPluggableNameResolver,
+    IReactorTCP,
+    IReactorTime,
+)
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.stringutils import parse_and_validate_server_name
@@ -67,6 +75,14 @@ MutableStateMap = MutableMapping[StateKey, T]
 JsonDict = Dict[str, Any]
 
 
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+    IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+):
+    """The interfaces necessary for Synapse to function."""
+
+
 class Requester(
     namedtuple(
         "Requester",
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+    """Extracts a caveat value from a macaroon token.
+
+    Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+    and returns the extracted value.
+
+    Args:
+        macaroon: the token
+        key: the key of the caveat to extract
+
+    Returns:
+        The extracted value
+
+    Raises:
+        MacaroonVerificationFailedException: if there are conflicting values for the
+             caveat in the macaroon, or if the caveat was not found in the macaroon.
+    """
+    prefix = key + " = "
+    result = None  # type: Optional[str]
+    for caveat in macaroon.caveats:
+        if not caveat.caveat_id.startswith(prefix):
+            continue
+
+        val = caveat.caveat_id[len(prefix) :]
+
+        if result is None:
+            # first time we found this caveat: record the value
+            result = val
+        elif val != result:
+            # on subsequent occurrences, raise if the value is different.
+            raise MacaroonVerificationFailedException(
+                "Conflicting values for caveat " + key
+            )
+
+    if result is not None:
+        return result
+
+    # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+    # Note that it is insecure to generate a macaroon without all the caveats you
+    # might need (because there is nothing stopping people from adding extra caveats),
+    # so if the caveat isn't there, something odd must be going on.
+    raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+    """Make a macaroon verifier which accepts 'time' caveats
+
+    Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+    the given macaroon verifier.
+
+    Args:
+        v: the macaroon verifier
+        get_time_ms: a callable which will return the timestamp after which the caveat
+            should be considered expired. Normally the current time.
+    """
+
+    def verify_expiry_caveat(caveat: str):
+        time_msec = get_time_ms()
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        return time_msec < expiry
+
+    v.satisfy_general(verify_expiry_caveat)