summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py736
1 files changed, 692 insertions, 44 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+)
+from urllib.parse import urlencode
 
 import attr
+from typing_extensions import NoReturn, Protocol
 
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
-from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from twisted.web.http import Request
+from twisted.web.iweb import IRequest
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
+from synapse.http.server import respond_with_html, respond_with_redirect
+from synapse.http.site import SynapseRequest
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,24 +55,190 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+
+
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
 
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=Collection[str])
 
-class SsoHandler(BaseHandler):
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._clock = hs.get_clock()
+        self._store = hs.get_datastore()
+        self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
+        self._auth_handler = hs.get_auth_handler()
         self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
+
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
 
     def render_error(
-        self, request, error: str, error_description: Optional[str] = None
+        self,
+        request: Request,
+        error: str,
+        error_description: Optional[str] = None,
+        code: int = 400,
     ) -> None:
         """Renders the error template and responds with it.
 
@@ -64,11 +250,53 @@ class SsoHandler(BaseHandler):
                 We'll respond with an HTML page describing the error.
             error: A technical identifier for this error.
             error_description: A human-readable description of the error.
+            code: The integer error code (an HTTP response code)
         """
         html = self._error_template.render(
             error=error, error_description=error_description
         )
-        respond_with_html(request, 400, html)
+        respond_with_html(request, code, html)
+
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+            idp_id: optional identity provider chosen by the client
+
+        Returns:
+             the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
+        # if we only have one auth provider, redirect to it directly
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url),)
+        )
 
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
@@ -95,7 +323,7 @@ class SsoHandler(BaseHandler):
         )
 
         # Check if we already have a mapping for this user.
-        previously_registered_user_id = await self.store.get_user_by_external_id(
+        previously_registered_user_id = await self._store.get_user_by_external_id(
             auth_provider_id, remote_user_id,
         )
 
@@ -112,15 +340,16 @@ class SsoHandler(BaseHandler):
         # No match.
         return None
 
-    async def get_mxid_from_sso(
+    async def complete_sso_login_request(
         self,
         auth_provider_id: str,
         remote_user_id: str,
-        user_agent: str,
-        ip_address: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
-    ) -> str:
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+        extra_login_attributes: Optional[JsonDict] = None,
+    ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
 
@@ -139,12 +368,18 @@ class SsoHandler(BaseHandler):
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
 
+        Finally, we generate a redirect to the supplied redirect uri, with a login token
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
+
             remote_user_id: The unique identifier from the SSO provider.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+
+            request: The request to respond to
+
+            client_redirect_url: The redirect URL passed in by the client.
+
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
@@ -156,12 +391,13 @@ class SsoHandler(BaseHandler):
                         to the user.
                     RedirectException to redirect to an additional page (e.g.
                         to prompt the user for more information).
+
             grandfather_existing_users: A callable which can return an previously
                 existing matrix ID. The SSO ID is then linked to the returned
                 matrix ID.
 
-        Returns:
-             The user ID associated with the SSO response.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -169,24 +405,62 @@ class SsoHandler(BaseHandler):
                 to an additional page. (e.g. to prompt for more information)
 
         """
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
-        )
-        if previously_registered_user_id:
-            return previously_registered_user_id
+        new_user = False
+
+        # grab a lock while we try to find a mapping for this user. This seems...
+        # optimistic, especially for implementations that end up redirecting to
+        # interstitial pages.
+        with await self._mapping_lock.queue(auth_provider_id):
+            # first of all, check if we already have a mapping for this user
+            user_id = await self.get_sso_user_by_remote_user_id(
+                auth_provider_id, remote_user_id,
+            )
 
-        # Check for grandfathering of users.
-        if grandfather_existing_users:
-            previously_registered_user_id = await grandfather_existing_users()
-            if previously_registered_user_id:
-                # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
+            # Check for grandfathering of users.
+            if not user_id:
+                user_id = await grandfather_existing_users()
+                if user_id:
+                    # Future logins should also match this user ID.
+                    await self._store.record_user_external_id(
+                        auth_provider_id, remote_user_id, user_id
+                    )
+
+            # Otherwise, generate a new user.
+            if not user_id:
+                attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
+                user_id = await self._register_mapped_user(
+                    attributes,
+                    auth_provider_id,
+                    remote_user_id,
+                    get_request_user_agent(request),
+                    request.getClientIP(),
                 )
-                return previously_registered_user_id
+                new_user = True
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
+        )
 
-        # Otherwise, generate a new user.
+    async def _call_attribute_mapper(
+        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+    ) -> UserAttributes:
+        """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
                 attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +482,12 @@ class SsoHandler(BaseHandler):
             )
 
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
 
             # Check if this mxid already exists
-            user_id = UserID(attributes.localpart, self.server_name).to_string()
-            if not await self.store.get_users_by_id_case_insensitive(user_id):
+            user_id = UserID(attributes.localpart, self._server_name).to_string()
+            if not await self._store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
@@ -224,10 +496,101 @@ class SsoHandler(BaseHandler):
             raise MappingException(
                 "Unable to generate a Matrix ID from the SSO response"
             )
+        return attributes
+
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
+    async def _register_mapped_user(
+        self,
+        attributes: UserAttributes,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+             attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
 
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
 
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -238,7 +601,292 @@ class SsoHandler(BaseHandler):
             user_agent_ips=[(user_agent, ip_address)],
         )
 
-        await self.store.record_user_external_id(
+        await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
         return registered_user_id
+
+    async def complete_sso_ui_auth_request(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        ui_auth_session_id: str,
+        request: Request,
+    ) -> None:
+        """
+        Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+        Note that this requires that the user is mapped in the "user_external_ids"
+        table. This will be the case if they have ever logged in via SAML or OIDC in
+        recentish synapse versions, but may not be for older users.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            ui_auth_session_id: The ID of the user-interactive auth session.
+            request: The request to complete.
+        """
+
+        user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        if not user_id:
+            logger.warning(
+                "Remote user %s/%s has not previously logged in here: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+            )
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
+
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+        )
+        respond_with_html(request, 200, html)
+
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+        self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
+        """
+        session = self.get_mapping_session(session_id)
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set conjunction, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
+
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        session = self.get_mapping_session(session_id)
+        session.terms_accepted_version = terms_version
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
+
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
+
+        attributes = UserAttributes(
+            localpart=session.chosen_localpart, emails=session.emails_to_use,
+        )
+
+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            get_request_user_agent(request),
+            request.getClientIP(),
+        )
+
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+            new_user=True,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")