summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
commitdcf1b9c276e22bb6f5200fc029301c4d40e87a1f (patch)
tree1f5badce24645d99534133a7a989069906088fff /synapse/handlers/oidc_handler.py
parentMerge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-... (diff)
parentFixup CHANGES (diff)
downloadsynapse-dcf1b9c276e22bb6f5200fc029301c4d40e87a1f.tar.xz
Merge remote-tracking branch 'origin/release-v1.27.0' into bbz/info-mainline-1.27.0 github/bbz/info-mainline-1.27.0 bbz/info-mainline-1.27.0
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py745
1 files changed, 434 insertions, 311 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..71008ec50d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_confs = hs.config.oidc.oidc_providers
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_confs
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }  # type: Dict[str, OidcProvider]
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match with the one stored in the
+            session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except (MacaroonDeserializationException, ValueError) as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -85,46 +223,64 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
+        self._store = hs.get_datastore()
+
+        self._token_generator = token_generator
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
-        self._scopes = hs.config.oidc_scopes  # type: List[str]
-        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
+
+        self._scopes = provider.scopes
+        self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
-            hs.config.oidc_client_id,
-            hs.config.oidc_client_secret,
-            hs.config.oidc_client_auth_method,
+            provider.client_id, provider.client_secret, provider.client_auth_method,
         )  # type: ClientAuth
-        self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
+        self._client_auth_method = provider.client_auth_method
         self._provider_metadata = OpenIDProviderMetadata(
-            issuer=hs.config.oidc_issuer,
-            authorization_endpoint=hs.config.oidc_authorization_endpoint,
-            token_endpoint=hs.config.oidc_token_endpoint,
-            userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
-            jwks_uri=hs.config.oidc_jwks_uri,
+            issuer=provider.issuer,
+            authorization_endpoint=provider.authorization_endpoint,
+            token_endpoint=provider.token_endpoint,
+            userinfo_endpoint=provider.userinfo_endpoint,
+            jwks_uri=provider.jwks_uri,
         )  # type: OpenIDProviderMetadata
-        self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
-        self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
-            hs.config.oidc_user_mapping_provider_config
-        )  # type: OidcMappingProvider
-        self._skip_verification = hs.config.oidc_skip_verification  # type: bool
-        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
+        self._provider_needs_discovery = provider.discover
+        self._user_mapping_provider = provider.user_mapping_provider_class(
+            provider.user_mapping_provider_config
+        )
+        self._skip_verification = provider.skip_verification
+        self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = provider.idp_id
+
+        # user-facing name of this auth provider
+        self.idp_name = provider.idp_name
+
+        # MXC URI for icon for this auth provider
+        self.idp_icon = provider.idp_icon
+
+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -477,7 +633,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -487,7 +643,7 @@ class OidcHandler(BaseHandler):
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
           - ``response_type``: ``code``
-          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
           - ``scope``: the list of scopes set in ``oidc_config.scopes``
           - ``state``: a random string
           - ``nonce``: a random string
@@ -501,7 +657,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -513,16 +669,22 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
-        cookie = self._generate_oidc_session_token(
+        if not client_redirect_url:
+            client_redirect_url = b""
+
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url.decode(),
-            ui_auth_session_id=ui_auth_session_id,
+            session_data=OidcSessionData(
+                idp_id=self.idp_id,
+                nonce=nonce,
+                client_redirect_url=client_redirect_url.decode(),
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
             cookie,
-            path="/_synapse/oidc",
+            path="/_synapse/client/oidc",
             max_age="3600",
             httpOnly=True,
             sameSite="lax",
@@ -540,22 +702,16 @@ class OidcHandler(BaseHandler):
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
-
-        Most of the OpenID Connect logic happens here:
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-          - first, we check if there was any error returned by the provider and
-            display it
-          - then we fetch the session cookie, decode and verify it
-          - the ``state`` query parameter should match with the one stored in the
-            session cookie
-          - once we known this session is legit, exchange the code with the
-            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - exchange the code with the provider using the ``token_endpoint`` (see
+            ``_exchange_code``)
           - once we have the token, use it to either extract the UserInfo from
             the ``id_token`` (``_parse_id_token``), or use the ``access_token``
             to fetch UserInfo from the ``userinfo_endpoint``
@@ -565,88 +721,12 @@ class OidcHandler(BaseHandler):
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
@@ -668,25 +748,131 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
+        # first check if we're doing a UIA
+        if session_data.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_userinfo(userinfo)
+            except Exception as e:
+                logger.exception("Could not extract remote user id")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
+            )
+
+        # otherwise, it's a login
 
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(
-                userinfo, token, user_agent, ip_address
+            await self._complete_oidc_login(
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
+
+    async def _complete_oidc_login(
+        self,
+        userinfo: UserInfo,
+        token: Token,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
+        """Given a UserInfo response, complete the login flow
+
+        UserInfo should have a claim that uniquely identifies users. This claim
+        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+        It is then used as an `external_id`.
+
+        If we don't find the user that way, we should register the user,
+        mapping the localpart and the display name from the UserInfo.
+
+        If a user already exists with the mxid we've mapped and allow_existing_users
+        is disabled, raise an exception.
+
+        Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
+        Args:
+            userinfo: an object representing the user
+            token: a dict with the tokens obtained from the provider
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Raises:
+            MappingException: if there was an error while mapping some properties
+        """
+        try:
+            remote_user_id = self._remote_id_from_userinfo(userinfo)
+        except Exception as e:
+            raise MappingException(
+                "Failed to extract subject from OIDC response: %s" % (e,)
+            )
+
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
+        )
+        supports_failures = "failures" in mapper_signature.parameters
+
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # do not continually generate the same Matrix ID since it will
+                # continue to already be in use. Note that the error raised is
+                # arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise MappingException(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
+
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
+
+            return UserAttributes(**attributes)
+
+        async def grandfather_existing_users() -> Optional[str]:
+            if self._allow_existing_users:
+                # If allowing existing users we want to generate a single localpart
+                # and attempt to match it.
+                attributes = await oidc_response_to_user_attributes(failures=0)
+
+                user_id = UserID(attributes.localpart, self._server_name).to_string()
+                users = await self._store.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    # If an existing matrix ID is returned, then use it.
+                    if len(users) == 1:
+                        previously_registered_user_id = next(iter(users))
+                    elif user_id in users:
+                        previously_registered_user_id = user_id
+                    else:
+                        # Do not attempt to continue generating Matrix IDs.
+                        raise MappingException(
+                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                                user_id, users
+                            )
+                        )
+
+                    return previously_registered_user_id
+
+            return None
 
         # Mapping providers might not have get_extra_attributes: only call this
         # method if it exists.
@@ -697,22 +883,42 @@ class OidcHandler(BaseHandler):
         if get_extra_attributes:
             extra_attributes = await get_extra_attributes(userinfo, token)
 
-        # and finally complete the login
-        if ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url, extra_attributes
-            )
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            remote_user_id,
+            request,
+            client_redirect_url,
+            oidc_response_to_user_attributes,
+            grandfather_existing_users,
+            extra_attributes,
+        )
+
+    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+        """Extract the unique remote id from an OIDC UserInfo block
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+        Returns:
+            remote user id
+        """
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        return str(remote_user_id)
+
+
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
 
-    def _generate_oidc_session_token(
+    def generate_oidc_session_token(
         self,
         state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
+        session_data: "OidcSessionData",
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -725,11 +931,7 @@ class OidcHandler(BaseHandler):
 
         Args:
             state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
+            session_data: data to include in the session token.
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
 
@@ -742,23 +944,24 @@ class OidcHandler(BaseHandler):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
         macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
         macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
             )
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
         return macaroon.serialize()
 
-    def _verify_oidc_session_token(
+    def verify_oidc_session_token(
         self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
+    ) -> "OidcSessionData":
         """Verifies and extract an OIDC session token.
 
         This verifies that a given session token was issued by this homeserver
@@ -769,7 +972,10 @@ class OidcHandler(BaseHandler):
             state: The state the OIDC provider gave back
 
         Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
+            The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -778,6 +984,7 @@ class OidcHandler(BaseHandler):
         v.satisfy_exact("type = session")
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
         # Sometimes there's a UI auth session ID, it seems to be OK to attempt
         # to always satisfy this.
@@ -786,9 +993,9 @@ class OidcHandler(BaseHandler):
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
+        # Extract the session data from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
@@ -799,7 +1006,12 @@ class OidcHandler(BaseHandler):
         except ValueError:
             ui_auth_session_id = None
 
-        return nonce, client_redirect_url, ui_auth_session_id
+        return OidcSessionData(
+            nonce=nonce,
+            idp_id=idp_id,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
 
     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -812,7 +1024,7 @@ class OidcHandler(BaseHandler):
             The extracted value
 
         Raises:
-            Exception: if the caveat was not in the macaroon
+            ValueError: if the caveat was not in the macaroon
         """
         prefix = key + " = "
         for caveat in macaroon.caveats:
@@ -825,117 +1037,30 @@ class OidcHandler(BaseHandler):
         if not caveat.startswith(prefix):
             return False
         expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(
-        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
-    ) -> str:
-        """Maps a UserInfo object to a mxid.
-
-        UserInfo should have a claim that uniquely identifies users. This claim
-        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
-        It is then used as an `external_id`.
-
-        If we don't find the user that way, we should register the user,
-        mapping the localpart and the display name from the UserInfo.
-
-        If a user already exists with the mxid we've mapped and allow_existing_users
-        is disabled, raise an exception.
-
-        Args:
-            userinfo: an object representing the user
-            token: a dict with the tokens obtained from the provider
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Raises:
-            MappingException: if there was an error while mapping some properties
-
-        Returns:
-            The mxid of the user
-        """
-        try:
-            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
-        except Exception as e:
-            raise MappingException(
-                "Failed to extract subject from OIDC response: %s" % (e,)
-            )
-        # Some OIDC providers use integer IDs, but Synapse expects external IDs
-        # to be strings.
-        remote_user_id = str(remote_user_id)
-
-        # Older mapping providers don't accept the `failures` argument, so we
-        # try and detect support.
-        mapper_signature = inspect.signature(
-            self._user_mapping_provider.map_user_attributes
-        )
-        supports_failures = "failures" in mapper_signature.parameters
-
-        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
-            """
-            Call the mapping provider to map the OIDC userinfo and token to user attributes.
-
-            This is backwards compatibility for abstraction for the SSO handler.
-            """
-            if supports_failures:
-                attributes = await self._user_mapping_provider.map_user_attributes(
-                    userinfo, token, failures
-                )
-            else:
-                # If the mapping provider does not support processing failures,
-                # do not continually generate the same Matrix ID since it will
-                # continue to already be in use. Note that the error raised is
-                # arbitrary and will get turned into a MappingException.
-                if failures:
-                    raise MappingException(
-                        "Mapping provider does not support de-duplicating Matrix IDs"
-                    )
-
-                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
-                    userinfo, token
-                )
 
-            return UserAttributes(**attributes)
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
 
-        async def grandfather_existing_users() -> Optional[str]:
-            if self._allow_existing_users:
-                # If allowing existing users we want to generate a single localpart
-                # and attempt to match it.
-                attributes = await oidc_response_to_user_attributes(failures=0)
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
 
-                user_id = UserID(attributes.localpart, self.server_name).to_string()
-                users = await self.store.get_users_by_id_case_insensitive(user_id)
-                if users:
-                    # If an existing matrix ID is returned, then use it.
-                    if len(users) == 1:
-                        previously_registered_user_id = next(iter(users))
-                    elif user_id in users:
-                        previously_registered_user_id = user_id
-                    else:
-                        # Do not attempt to continue generating Matrix IDs.
-                        raise MappingException(
-                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                                user_id, users
-                            )
-                        )
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
 
-                    return previously_registered_user_id
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
 
-            return None
-
-        return await self._sso_handler.get_mxid_from_sso(
-            self._auth_provider_id,
-            remote_user_id,
-            user_agent,
-            ip_address,
-            oidc_response_to_user_attributes,
-            grandfather_existing_users,
-        )
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")
 
@@ -1014,12 +1139,13 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,50 +1161,37 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
-
-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                display_name_template = env.from_string(config["display_name_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                raise ConfigError("invalid jinja template", path=[option_name]) from e
+
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])
 
             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )
 
@@ -1088,25 +1201,35 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
 
-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()
 
-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)
 
-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""
 
-            if display_name == "":
-                display_name = None
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()
 
-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None
+
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
+
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]