summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2021-02-01 17:28:37 +0000
committerRichard van der Hoff <richard@matrix.org>2021-02-01 17:28:37 +0000
commit18ab35284a2270efe01815911885e45b0f743453 (patch)
treea0b6c0322233525e91b84e2b4c08800cfa9d7c18
parentAdd phone home stats for encrypted messages. (#9283) (diff)
parentReplace username picker with a template (#9275) (diff)
downloadsynapse-18ab35284a2270efe01815911885e45b0f743453.tar.xz
Merge branch 'social_login' into develop
-rw-r--r--changelog.d/9262.feature1
-rw-r--r--changelog.d/9271.bugfix1
-rw-r--r--changelog.d/9272.feature1
-rw-r--r--changelog.d/9275.feature1
-rw-r--r--docs/sample_config.yaml46
-rw-r--r--docs/workers.md18
-rw-r--r--synapse/app/generic_worker.py11
-rw-r--r--synapse/app/homeserver.py16
-rw-r--r--synapse/config/_base.py39
-rw-r--r--synapse/config/oidc_config.py3
-rw-r--r--synapse/config/sso.py47
-rw-r--r--synapse/handlers/auth.py24
-rw-r--r--synapse/handlers/sso.py93
-rw-r--r--synapse/http/server.py7
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/res/templates/sso.css83
-rw-r--r--synapse/res/templates/sso_auth_account_details.html115
-rw-r--r--synapse/res/templates/sso_auth_account_details.js76
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html32
-rw-r--r--synapse/res/username_picker/index.html19
-rw-r--r--synapse/res/username_picker/script.js95
-rw-r--r--synapse/res/username_picker/style.css27
-rw-r--r--synapse/rest/consent/consent_resource.py1
-rw-r--r--synapse/rest/synapse/client/__init__.py49
-rw-r--r--synapse/rest/synapse/client/pick_username.py91
-rw-r--r--synapse/rest/synapse/client/sso_register.py50
-rw-r--r--synapse/storage/databases/main/registration.py40
-rw-r--r--synapse/util/templates.py106
-rw-r--r--tests/handlers/test_cas.py8
-rw-r--r--tests/handlers/test_oidc.py24
-rw-r--r--tests/handlers/test_saml.py8
-rw-r--r--tests/rest/client/v1/test_login.py30
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py6
33 files changed, 860 insertions, 318 deletions
diff --git a/changelog.d/9262.feature b/changelog.d/9262.feature
new file mode 100644
index 0000000000..c21b197ca1
--- /dev/null
+++ b/changelog.d/9262.feature
@@ -0,0 +1 @@
+Improve the user experience of setting up an account via single-sign on.
diff --git a/changelog.d/9271.bugfix b/changelog.d/9271.bugfix
new file mode 100644
index 0000000000..ef30c6570f
--- /dev/null
+++ b/changelog.d/9271.bugfix
@@ -0,0 +1 @@
+Fix single-sign-on when the endpoints are routed to synapse workers.
diff --git a/changelog.d/9272.feature b/changelog.d/9272.feature
new file mode 100644
index 0000000000..c21b197ca1
--- /dev/null
+++ b/changelog.d/9272.feature
@@ -0,0 +1 @@
+Improve the user experience of setting up an account via single-sign on.
diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature
new file mode 100644
index 0000000000..c21b197ca1
--- /dev/null
+++ b/changelog.d/9275.feature
@@ -0,0 +1 @@
+Improve the user experience of setting up an account via single-sign on.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 7fd35516dc..a669a241da 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1815,7 +1815,8 @@ saml2_config:
 #
 #             localpart_template: Jinja2 template for the localpart of the MXID.
 #                 If this is not set, the user will be prompted to choose their
-#                 own username.
+#                 own username (see 'sso_auth_account_details.html' in the 'sso'
+#                 section of this file).
 #
 #             display_name_template: Jinja2 template for the display name to set
 #                 on first login. If unset, no displayname will be set.
@@ -1978,10 +1979,40 @@ sso:
     #
     #     * idp: the 'idp_id' of the chosen IDP.
     #
+    # * HTML page to prompt new users to enter a userid and confirm other
+    #   details: 'sso_auth_account_details.html'. This is only shown if the
+    #   SSO implementation (with any user_mapping_provider) does not return
+    #   a localpart.
+    #
+    #   When rendering, this template is given the following variables:
+    #
+    #     * server_name: the homeserver's name.
+    #
+    #     * idp: details of the SSO Identity Provider that the user logged in
+    #       with: an object with the following attributes:
+    #
+    #         * idp_id: unique identifier for the IdP
+    #         * idp_name: user-facing name for the IdP
+    #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+    #              for the IdP
+    #         * idp_brand: if specified in the IdP config, a textual identifier
+    #              for the brand of the IdP
+    #
+    #     * user_attributes: an object containing details about the user that
+    #       we received from the IdP. May have the following attributes:
+    #
+    #         * display_name: the user's display_name
+    #         * emails: a list of email addresses
+    #
+    #   The template should render a form which submits the following fields:
+    #
+    #     * username: the localpart of the user's chosen user id
+    #
     # * HTML page for a confirmation step before redirecting back to the client
     #   with the login token: 'sso_redirect_confirm.html'.
     #
-    #   When rendering, this template is given three variables:
+    #   When rendering, this template is given the following variables:
+    #
     #     * redirect_url: the URL the user is about to be redirected to. Needs
     #                     manual escaping (see
     #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
@@ -1994,6 +2025,17 @@ sso:
     #
     #     * server_name: the homeserver's name.
     #
+    #     * new_user: a boolean indicating whether this is the user's first time
+    #          logging in.
+    #
+    #     * user_id: the user's matrix ID.
+    #
+    #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+    #           None if the user has not set an avatar.
+    #
+    #     * user_profile.display_name: the user's display name. None if the user
+    #           has not set a display name.
+    #
     # * HTML page which notifies the user that they are authenticating to confirm
     #   an operation on their account during the user interactive authentication
     #   process: 'sso_auth_confirm.html'.
diff --git a/docs/workers.md b/docs/workers.md
index d2927d95a6..bd8c9f95cb 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -228,7 +228,6 @@ expressions:
     ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
-    ^/_synapse/client/password_reset/email/submit_token$
 
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|unstable)/login$
@@ -259,25 +258,28 @@ Additionally, the following endpoints should be included if Synapse is configure
 to use SSO (you only need to include the ones for whichever SSO provider you're
 using):
 
+    # for all SSO providers
+    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect
+    ^/_synapse/client/pick_idp$
+    ^/_synapse/client/pick_username
+    ^/_synapse/client/sso_register$
+
     # OpenID Connect requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
     ^/_synapse/oidc/callback$
 
     # SAML requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
     ^/_matrix/saml2/authn_response$
 
     # CAS requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
     ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
 
-Note that a HTTP listener with `client` and `federation` resources must be
-configured in the `worker_listeners` option in the worker config.
-
-Ensure that all SSO logins go to a single process (usually the main process).
+Ensure that all SSO logins go to a single process.
 For multiple workers not handling the SSO endpoints properly, see
 [#7530](https://github.com/matrix-org/synapse/issues/7530).
 
+Note that a HTTP listener with `client` and `federation` resources must be
+configured in the `worker_listeners` option in the worker config.
+
 #### Load balancing
 
 It is possible to run multiple instances of this worker app, with incoming requests
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index e60988fa4a..516f2464b4 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set
 from typing_extensions import ContextManager
 
 from twisted.internet import address
+from twisted.web.resource import IResource
 
 import synapse
 import synapse.events
@@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import (
     ToDeviceStream,
 )
 from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events, room
+from synapse.rest.client.v1 import events, login, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.login import LoginRestServlet
 from synapse.rest.client.v1.profile import (
     ProfileAvatarURLRestServlet,
     ProfileDisplaynameRestServlet,
@@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
@@ -507,7 +508,7 @@ class GenericWorkerServer(HomeServer):
             site_tag = port
 
         # We always include a health resource.
-        resources = {"/health": HealthResource()}
+        resources = {"/health": HealthResource()}  # type: Dict[str, IResource]
 
         for res in listener_config.http_options.resources:
             for name in res.names:
@@ -517,7 +518,7 @@ class GenericWorkerServer(HomeServer):
                     resource = JsonResource(self, canonical_json=False)
 
                     RegisterRestServlet(self).register(resource)
-                    LoginRestServlet(self).register(resource)
+                    login.register_servlets(self, resource)
                     ThreepidRestServlet(self).register(resource)
                     DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
@@ -557,6 +558,8 @@ class GenericWorkerServer(HomeServer):
                     groups.register_servlets(self, resource)
 
                     resources.update({CLIENT_API_PREFIX: resource})
+
+                    resources.update(build_synapse_client_resource_tree(self))
                 elif name == "federation":
                     resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
                 elif name == "media":
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 57a2f5237c..244657cb88 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -60,8 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -190,21 +189,10 @@ class SynapseHomeServer(HomeServer):
                     "/_matrix/client/versions": client_resource,
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
-                    "/_synapse/client/pick_username": pick_username_resource(self),
-                    "/_synapse/client/pick_idp": PickIdpResource(self),
+                    **build_synapse_client_resource_tree(self),
                 }
             )
 
-            if self.get_config().oidc_enabled:
-                from synapse.rest.oidc import OIDCResource
-
-                resources["/_synapse/oidc"] = OIDCResource(self)
-
-            if self.get_config().saml2_enabled:
-                from synapse.rest.saml2 import SAML2Resource
-
-                resources["/_matrix/saml2"] = SAML2Resource(self)
-
             if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 6a0768ce00..a851f8801d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,18 +18,18 @@
 import argparse
 import errno
 import os
-import time
-import urllib.parse
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Callable, Iterable, List, MutableMapping, Optional
+from typing import Any, Iterable, List, MutableMapping, Optional
 
 import attr
 import jinja2
 import pkg_resources
 import yaml
 
+from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
+
 
 class ConfigError(Exception):
     """Represents a problem parsing the configuration
@@ -262,6 +262,7 @@ class Config:
             # Search the custom template directory as well
             search_directories.insert(0, custom_template_directory)
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(search_directories)
         env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
 
@@ -277,38 +278,6 @@ class Config:
         return [env.get_template(filename) for filename in filenames]
 
 
-def _format_ts_filter(value: int, format: str):
-    return time.strftime(format, time.localtime(value / 1000))
-
-
-def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
-    """Create and return a jinja2 filter that converts MXC urls to HTTP
-
-    Args:
-        public_baseurl: The public, accessible base URL of the homeserver
-    """
-
-    def mxc_to_http_filter(value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        server_and_media_id = value[6:]
-        fragment = None
-        if "#" in server_and_media_id:
-            server_and_media_id, fragment = server_and_media_id.split("#", 1)
-            fragment = "#" + fragment
-
-        params = {"width": width, "height": height, "method": resize_method}
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            public_baseurl,
-            server_and_media_id,
-            urllib.parse.urlencode(params),
-            fragment or "",
-        )
-
-    return mxc_to_http_filter
-
-
 class RootConfig:
     """
     Holder of an application's configuration.
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index b71aae0b53..bb122ef182 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -151,7 +151,8 @@ class OIDCConfig(Config):
         #
         #             localpart_template: Jinja2 template for the localpart of the MXID.
         #                 If this is not set, the user will be prompted to choose their
-        #                 own username.
+        #                 own username (see 'sso_auth_account_details.html' in the 'sso'
+        #                 section of this file).
         #
         #             display_name_template: Jinja2 template for the display name to set
         #                 on first login. If unset, no displayname will be set.
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 59be825532..e308fc9333 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -27,7 +27,7 @@ class SSOConfig(Config):
         sso_config = config.get("sso") or {}  # type: Dict[str, Any]
 
         # The sso-specific template_dir
-        template_dir = sso_config.get("template_dir")
+        self.sso_template_dir = sso_config.get("template_dir")
 
         # Read templates from disk
         (
@@ -48,7 +48,7 @@ class SSOConfig(Config):
                 "sso_auth_success.html",
                 "sso_auth_bad_user.html",
             ],
-            template_dir,
+            self.sso_template_dir,
         )
 
         # These templates have no placeholders, so render them here
@@ -124,10 +124,40 @@ class SSOConfig(Config):
             #
             #     * idp: the 'idp_id' of the chosen IDP.
             #
+            # * HTML page to prompt new users to enter a userid and confirm other
+            #   details: 'sso_auth_account_details.html'. This is only shown if the
+            #   SSO implementation (with any user_mapping_provider) does not return
+            #   a localpart.
+            #
+            #   When rendering, this template is given the following variables:
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * idp: details of the SSO Identity Provider that the user logged in
+            #       with: an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
+            #     * user_attributes: an object containing details about the user that
+            #       we received from the IdP. May have the following attributes:
+            #
+            #         * display_name: the user's display_name
+            #         * emails: a list of email addresses
+            #
+            #   The template should render a form which submits the following fields:
+            #
+            #     * username: the localpart of the user's chosen user id
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
-            #   When rendering, this template is given three variables:
+            #   When rendering, this template is given the following variables:
+            #
             #     * redirect_url: the URL the user is about to be redirected to. Needs
             #                     manual escaping (see
             #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
@@ -140,6 +170,17 @@ class SSOConfig(Config):
             #
             #     * server_name: the homeserver's name.
             #
+            #     * new_user: a boolean indicating whether this is the user's first time
+            #          logging in.
+            #
+            #     * user_id: the user's matrix ID.
+            #
+            #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+            #           None if the user has not set an avatar.
+            #
+            #     * user_profile.display_name: the user's display name. None if the user
+            #           has not set a display name.
+            #
             # * HTML page which notifies the user that they are authenticating to confirm
             #   an operation on their account during the user interactive authentication
             #   process: 'sso_auth_confirm.html'.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3127357964..143494ae99 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
@@ -1386,6 +1387,7 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1396,6 +1398,8 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1404,8 +1408,17 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
         self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
         )
 
     def _complete_sso_login(
@@ -1414,12 +1427,18 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
@@ -1457,6 +1476,9 @@ class AuthHandler(BaseHandler):
             display_url=redirect_url_no_params,
             redirect_url=redirect_url,
             server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 3308b037d2..ff4750999a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -21,12 +21,13 @@ import attr
 from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
+from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
@@ -141,6 +142,9 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)
 
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+
 
 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -387,6 +391,8 @@ class SsoHandler:
                 to an additional page. (e.g. to prompt for more information)
 
         """
+        new_user = False
+
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
@@ -427,9 +433,14 @@ class SsoHandler:
                     get_request_user_agent(request),
                     request.getClientIP(),
                 )
+                new_user = True
 
         await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url, extra_login_attributes
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
         )
 
     async def _call_attribute_mapper(
@@ -519,7 +530,7 @@ class SsoHandler:
         logger.info("Recorded registration session id %s", session_id)
 
         # Set the cookie and redirect to the username picker
-        e = RedirectException(b"/_synapse/client/pick_username")
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
         e.cookies.append(
             b"%s=%s; path=/"
             % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
@@ -647,6 +658,25 @@ class SsoHandler:
         )
         respond_with_html(request, 200, html)
 
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
     async def check_username_availability(
         self, localpart: str, session_id: str,
     ) -> bool:
@@ -663,12 +693,7 @@ class SsoHandler:
 
         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        self.get_mapping_session(session_id)
 
         logger.info(
             "[session %s] Checking for availability of username %s",
@@ -696,16 +721,33 @@ class SsoHandler:
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
         """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        session = self.get_mapping_session(session_id)
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
 
-        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
+
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
 
         attributes = UserAttributes(
-            localpart=localpart,
+            localpart=session.chosen_localpart,
             display_name=session.display_name,
             emails=session.emails,
         )
@@ -720,7 +762,12 @@ class SsoHandler:
             request.getClientIP(),
         )
 
-        logger.info("[session %s] Registered userid %s", session_id, user_id)
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
 
         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]
@@ -738,6 +785,7 @@ class SsoHandler:
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
+            new_user=True,
         )
 
     def _expire_old_sessions(self):
@@ -751,3 +799,14 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d69d579b3a..8249732b27 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request):
     request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
 
 
+def respond_with_redirect(request: Request, url: bytes) -> None:
+    """Write a 302 response to the request, if it is still alive."""
+    logger.debug("Redirect to %s", url.decode("utf-8"))
+    request.redirect(url)
+    finish_request(request)
+
+
 def finish_request(request: Request):
     """ Finish writing the response to the request.
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 72ab5750cc..401d577293 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -279,7 +279,11 @@ class ModuleApi:
         )
 
     async def complete_sso_login_async(
-        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+        new_user: bool = False,
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -291,9 +295,11 @@ class ModuleApi:
             request: The request to respond to.
             client_redirect_url: The URL to which to offer to redirect the user (or to
                 redirect them directly if whitelisted).
+            new_user: set to true to use wording for the consent appropriate to a user
+                who has just registered.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url,
+            registered_user_id, request, client_redirect_url, new_user=new_user
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css
new file mode 100644
index 0000000000..ff9dc94032
--- /dev/null
+++ b/synapse/res/templates/sso.css
@@ -0,0 +1,83 @@
+body {
+  font-family: "Inter", "Helvetica", "Arial", sans-serif;
+  font-size: 14px;
+  color: #17191C;
+}
+
+header {
+  max-width: 480px;
+  width: 100%;
+  margin: 24px auto;
+  text-align: center;
+}
+
+header p {
+  color: #737D8C;
+  line-height: 24px;
+}
+
+h1 {
+  font-size: 24px;
+}
+
+h2 {
+  font-size: 14px;
+}
+
+h2 img {
+  vertical-align: middle;
+  margin-right: 8px;
+  width: 24px;
+  height: 24px;
+}
+
+label {
+  cursor: pointer;
+}
+
+main {
+  max-width: 360px;
+  width: 100%;
+  margin: 24px auto;
+}
+
+.primary-button {
+  border: none;
+  text-decoration: none;
+  padding: 12px;
+  color: white;
+  background-color: #418DED;
+  font-weight: bold;
+  display: block;
+  border-radius: 12px;
+  width: 100%;
+  margin: 16px 0;
+  cursor: pointer;
+  text-align: center;
+}
+
+.profile {
+  display: flex;
+  justify-content: center;
+  margin: 24px 0;
+}
+
+.profile .avatar {
+  width: 36px;
+  height: 36px;
+  border-radius: 100%;
+  display: block;
+  margin-right: 8px;
+}
+
+.profile .display-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+.profile .user-id {
+  color: #737D8C;
+}
+
+.profile .display-name, .profile .user-id {
+  line-height: 18px;
+}
\ No newline at end of file
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
new file mode 100644
index 0000000000..f22b09aec1
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -0,0 +1,115 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <title>Synapse Login</title>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      .username_input {
+        display: flex;
+        border: 2px solid #418DED;
+        border-radius: 8px;
+        padding: 12px;
+        position: relative;
+        margin: 16px 0;
+        align-items: center;
+        font-size: 12px;
+      }
+
+      .username_input label {
+        position: absolute;
+        top: -8px;
+        left: 14px;
+        font-size: 80%;
+        background: white;
+        padding: 2px;
+      }
+
+      .username_input input {
+        flex: 1;
+        display: block;
+        min-width: 0;
+        border: none;
+      }
+
+      .username_input div {
+        color: #8D99A5;
+      }
+
+      .idp-pick-details {
+        border: 1px solid #E9ECF1;
+        border-radius: 8px;
+        margin: 24px 0;
+      }
+
+      .idp-pick-details h2 {
+        margin: 0;
+        padding: 8px 12px;
+      }
+
+      .idp-pick-details .idp-detail {
+        border-top: 1px solid #E9ECF1;
+        padding: 12px;
+      }
+
+      .idp-pick-details .use, .idp-pick-details .idp-value {
+        color: #737D8C;
+      }
+
+      .idp-pick-details .idp-value {
+        margin: 0;
+        margin-top: 8px;
+      }
+
+      .idp-pick-details .avatar {
+        width: 53px;
+        height: 53px;
+        border-radius: 100%;
+        display: block;
+        margin-top: 8px;
+      }
+    </style>
+  </head>
+  <body>
+    <header>
+      <h1>Your account is nearly ready</h1>
+      <p>Check your details before creating an account on {{ server_name }}</p>
+    </header>
+    <main>
+      <form method="post" class="form__input" id="form">
+        <div class="username_input">
+          <label for="field-username">Username</label>
+          <div class="prefix">@</div>
+          <input type="text" name="username" id="field-username" autofocus required pattern="[a-z0-9\-=_\/\.]+">
+          <div class="postfix">:{{ server_name }}</div>
+        </div>
+        <input type="submit" value="Continue" class="primary-button">
+        {% if user_attributes %}
+        <section class="idp-pick-details">
+          <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
+          {% if user_attributes.avatar_url %}
+          <div class="idp-detail idp-avatar">
+            <img src="{{ user_attributes.avatar_url }}" class="avatar" />
+          </div>
+          {% endif %}
+          {% if user_attributes.display_name %}
+          <div class="idp-detail">
+            <p class="idp-value">{{ user_attributes.display_name }}</p>
+          </div>
+          {% endif %}
+          {% for email in user_attributes.emails %}
+          <div class="idp-detail">
+            <p class="idp-value">{{ email }}</p>
+          </div>
+          {% endfor %}
+        </section>
+        {% endif %}
+      </form>
+    </main>
+    <script type="text/javascript">
+      {% include "sso_auth_account_details.js" without context %}
+    </script>
+  </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js
new file mode 100644
index 0000000000..deef419bb6
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.js
@@ -0,0 +1,76 @@
+const usernameField = document.getElementById("field-username");
+
+function throttle(fn, wait) {
+    let timeout;
+    return function() {
+        const args = Array.from(arguments);
+        if (timeout) {
+            clearTimeout(timeout);
+        }
+        timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait);
+    }
+}
+
+function checkUsernameAvailable(username) {
+    let check_uri = 'check?username=' + encodeURIComponent(username);
+    return fetch(check_uri, {
+        // include the cookie
+        "credentials": "same-origin",
+    }).then((response) => {
+        if(!response.ok) {
+            // for non-200 responses, raise the body of the response as an exception
+            return response.text().then((text) => { throw new Error(text); });
+        } else {
+            return response.json();
+        }
+    }).then((json) => {
+        if(json.error) {
+            return {message: json.error};
+        } else if(json.available) {
+            return {available: true};
+        } else {
+            return {message: username + " is not available, please choose another."};
+        }
+    });
+}
+
+function validateUsername(username) {
+    usernameField.setCustomValidity("");
+    if (usernameField.validity.valueMissing) {
+        usernameField.setCustomValidity("Please provide a username");
+        return;
+    }
+    if (usernameField.validity.patternMismatch) {
+        usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString);
+        return;
+    }
+    usernameField.setCustomValidity("Checking if username is available …");
+    throttledCheckUsernameAvailable(username);
+}
+
+const throttledCheckUsernameAvailable = throttle(function(username) {
+    const handleError =  function(err) {
+        // don't prevent form submission on error
+        usernameField.setCustomValidity("");
+        console.log(err.message);
+    };
+    try {
+        checkUsernameAvailable(username).then(function(result) {
+            if (!result.available) {
+                usernameField.setCustomValidity(result.message);
+                usernameField.reportValidity();
+            } else {
+                usernameField.setCustomValidity("");
+            }
+        }, handleError);
+    } catch (err) {
+        handleError(err);
+    }
+}, 500);
+
+usernameField.addEventListener("input", function(evt) {
+    validateUsername(usernameField.value);
+});
+usernameField.addEventListener("change", function(evt) {
+    validateUsername(usernameField.value);
+});
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index a45a50916c..d1328a6969 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -3,12 +3,34 @@
 <head>
     <meta charset="UTF-8">
     <title>SSO redirect confirmation</title>
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+    </style>
 </head>
     <body>
-        <p>The application at <span style="font-weight:bold">{{ display_url }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
-        <p>If you don't recognise this address, you should ignore this and close this tab.</p>
-        <p>
-            <a href="{{ redirect_url }}">I trust this address</a>
-        </p>
+        <header>
+            {% if new_user %}
+            <h1>Your account is now ready</h1>
+            <p>You've made your account on {{ server_name }}.</p>
+            {% else %}
+            <h1>Log in</h1>
+            {% endif %}
+            <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
+        </header>
+        <main>
+            {% if user_profile.avatar_url %}
+            <div class="profile">
+                <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
+                <div class="profile-details">
+                    {% if user_profile.display_name %}
+                    <div class="display-name">{{ user_profile.display_name }}</div>
+                    {% endif %}
+                    <div class="user-id">{{ user_id }}</div>
+                </div>
+            </div>
+            {% endif %}
+            <a href="{{ redirect_url }}" class="primary-button">Continue</a>
+        </main>
     </body>
 </html>
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
deleted file mode 100644
index 37ea8bb6d8..0000000000
--- a/synapse/res/username_picker/index.html
+++ /dev/null
@@ -1,19 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-  <head>
-    <title>Synapse Login</title>
-    <link rel="stylesheet" href="style.css" type="text/css" />
-  </head>
-  <body>
-    <div class="card">
-      <form method="post" class="form__input" id="form" action="submit">
-        <label for="field-username">Please pick your username:</label>
-        <input type="text" name="username" id="field-username" autofocus="">
-        <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
-      </form>
-      <!-- this is used for feedback -->
-      <div role=alert class="tooltip hidden" id="message"></div>
-      <script src="script.js"></script>
-    </div>
-  </body>
-</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
deleted file mode 100644
index 416a7c6f41..0000000000
--- a/synapse/res/username_picker/script.js
+++ /dev/null
@@ -1,95 +0,0 @@
-let inputField = document.getElementById("field-username");
-let inputForm = document.getElementById("form");
-let submitButton = document.getElementById("button-submit");
-let message = document.getElementById("message");
-
-// Submit username and receive response
-function showMessage(messageText) {
-    // Unhide the message text
-    message.classList.remove("hidden");
-
-    message.textContent = messageText;
-};
-
-function doSubmit() {
-    showMessage("Success. Please wait a moment for your browser to redirect.");
-
-    // remove the event handler before re-submitting the form.
-    delete inputForm.onsubmit;
-    inputForm.submit();
-}
-
-function onResponse(response) {
-    // Display message
-    showMessage(response);
-
-    // Enable submit button and input field
-    submitButton.classList.remove('button--disabled');
-    submitButton.value = "Submit";
-};
-
-let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
-function usernameIsValid(username) {
-    return !allowedUsernameCharacters.test(username);
-}
-let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
-
-function buildQueryString(params) {
-    return Object.keys(params)
-        .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
-        .join('&');
-}
-
-function submitUsername(username) {
-    if(username.length == 0) {
-        onResponse("Please enter a username.");
-        return;
-    }
-    if(!usernameIsValid(username)) {
-        onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
-        return;
-    }
-
-    // if this browser doesn't support fetch, skip the availability check.
-    if(!window.fetch) {
-        doSubmit();
-        return;
-    }
-
-    let check_uri = 'check?' + buildQueryString({"username": username});
-    fetch(check_uri, {
-        // include the cookie
-        "credentials": "same-origin",
-    }).then((response) => {
-        if(!response.ok) {
-            // for non-200 responses, raise the body of the response as an exception
-            return response.text().then((text) => { throw text; });
-        } else {
-            return response.json();
-        }
-    }).then((json) => {
-        if(json.error) {
-            throw json.error;
-        } else if(json.available) {
-            doSubmit();
-        } else {
-            onResponse("This username is not available, please choose another.");
-        }
-    }).catch((err) => {
-        onResponse("Error checking username availability: " + err);
-    });
-}
-
-function clickSubmit() {
-    event.preventDefault();
-    if(submitButton.classList.contains('button--disabled')) { return; }
-
-    // Disable submit button and input field
-    submitButton.classList.add('button--disabled');
-
-    // Submit username
-    submitButton.value = "Checking...";
-    submitUsername(inputField.value);
-};
-
-inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
deleted file mode 100644
index 745bd4c684..0000000000
--- a/synapse/res/username_picker/style.css
+++ /dev/null
@@ -1,27 +0,0 @@
-input[type="text"] {
-  font-size: 100%;
-  background-color: #ededf0;
-  border: 1px solid #fff;
-  border-radius: .2em;
-  padding: .5em .9em;
-  display: block;
-  width: 26em;
-}
-
-.button--disabled {
-  border-color: #fff;
-  background-color: transparent;
-  color: #000;
-  text-transform: none;
-}
-
-.hidden {
-  display: none;
-}
-
-.tooltip {
-  background-color: #f9f9fa;
-  padding: 1em;
-  margin: 1em 0;
-}
-
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
 
         consent_template_directory = hs.config.user_consent_template_dir
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(consent_template_directory)
         self._jinja_env = jinja2.Environment(
             loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..6acbc03d73 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,3 +12,50 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+    """Builds a resource tree to include synapse-specific client resources
+
+    These are resources which should be loaded on all workers which expose a C-S API:
+    ie, the main process, and any generic workers so configured.
+
+    Returns:
+         map from path to Resource.
+    """
+    resources = {
+        # SSO bits. These are always loaded, whether or not SSO login is actually
+        # enabled (they just won't work very well if it's not)
+        "/_synapse/client/pick_idp": PickIdpResource(hs),
+        "/_synapse/client/pick_username": pick_username_resource(hs),
+        "/_synapse/client/sso_register": SsoRegisterResource(hs),
+    }
+
+    # provider-specific SSO bits. Only load these if they are enabled, since they
+    # rely on optional dependencies.
+    if hs.config.oidc_enabled:
+        from synapse.rest.oidc import OIDCResource
+
+        resources["/_synapse/oidc"] = OIDCResource(hs)
+
+    if hs.config.saml2_enabled:
+        from synapse.rest.saml2 import SAML2Resource
+
+        # This is mounted under '/_matrix' for backwards-compatibility.
+        resources["/_matrix/saml2"] = SAML2Resource(hs)
+
+    return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d3b6803e65..27540d3bbe 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -12,42 +12,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
 
-import pkg_resources
+import logging
+from typing import TYPE_CHECKING
 
 from twisted.web.http import Request
 from twisted.web.resource import Resource
-from twisted.web.static import File
 
 from synapse.api.errors import SynapseError
-from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
-from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
 
 def pick_username_resource(hs: "HomeServer") -> Resource:
     """Factory method to generate the username picker resource.
 
-    This resource gets mounted under /_synapse/client/pick_username. The top-level
-    resource is just a File resource which serves up the static files in the resources
-    "res" directory, but it has a couple of children:
-
-    * "submit", which does the mechanics of registering the new user, and redirects the
-      browser back to the client URL
+    This resource gets mounted under /_synapse/client/pick_username and has two
+       children:
 
-    * "check": checks if a userid is free.
+      * "account_details": renders the form and handles the POSTed response
+      * "check": a JSON endpoint which checks if a userid is free.
     """
 
-    # XXX should we make this path customisable so that admins can restyle it?
-    base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
-
-    res = File(base_path)
-    res.putChild(b"submit", SubmitResource(hs))
+    res = Resource()
+    res.putChild(b"account_details", AccountDetailsResource(hs))
     res.putChild(b"check", AvailabilityCheckResource(hs))
 
     return res
@@ -61,28 +61,63 @@ class AvailabilityCheckResource(DirectServeJsonResource):
     async def _async_render_GET(self, request: Request):
         localpart = parse_string(request, "username", required=True)
 
-        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
-        if not session_id:
-            raise SynapseError(code=400, msg="missing session_id")
+        session_id = get_username_mapping_session_cookie_from_request(request)
 
         is_available = await self._sso_handler.check_username_availability(
-            localpart, session_id.decode("ascii", errors="replace")
+            localpart, session_id
         )
         return 200, {"available": is_available}
 
 
-class SubmitResource(DirectServeHtmlResource):
+class AccountDetailsResource(DirectServeHtmlResource):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._sso_handler = hs.get_sso_handler()
 
-    async def _async_render_POST(self, request: SynapseRequest):
-        localpart = parse_string(request, "username", required=True)
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        idp_id = session.auth_provider_id
+        template_params = {
+            "idp": self._sso_handler.get_identity_providers()[idp_id],
+            "user_attributes": {
+                "display_name": session.display_name,
+                "emails": session.emails,
+            },
+        }
+
+        template = self._jinja_env.get_template("sso_auth_account_details.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
 
-        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
-        if not session_id:
-            raise SynapseError(code=400, msg="missing session_id")
+    async def _async_render_POST(self, request: SynapseRequest):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            localpart = parse_string(request, "username", required=True)
+        except SynapseError as e:
+            logger.warning("[session %s] bad param: %s", session_id, e)
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
 
         await self._sso_handler.handle_submit_username_request(
-            request, localpart, session_id.decode("ascii", errors="replace")
+            request, localpart, session_id
         )
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+    """A resource which completes SSO registration
+
+    This resource gets mounted at /_synapse/client/sso_register, and is shown
+    after we collect username and/or consent for a new SSO user. It (finally) registers
+    the user, and confirms redirect to the client
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+        await self._sso_handler.register_sso_user(request, session_id)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0618b4387a..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -472,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
@@ -1400,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    async def record_user_external_id(
-        self, auth_provider: str, external_id: str, user_id: str
-    ) -> None:
-        """Record a mapping from an external user id to a mxid
-
-        Args:
-            auth_provider: identifier for the remote auth provider
-            external_id: id on that system
-            user_id: complete mxid that it is mapped to
-        """
-        await self.db_pool.simple_insert(
-            table="user_external_ids",
-            values={
-                "auth_provider": auth_provider,
-                "external_id": external_id,
-                "user_id": user_id,
-            },
-            desc="record_user_external_id",
-        )
-
     async def user_set_password_hash(
         self, user_id: str, password_hash: Optional[str]
     ) -> None:
diff --git a/synapse/util/templates.py b/synapse/util/templates.py
new file mode 100644
index 0000000000..7e5109d206
--- /dev/null
+++ b/synapse/util/templates.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with jinja2 templates"""
+
+import time
+import urllib.parse
+from typing import TYPE_CHECKING, Callable, Iterable, Union
+
+import jinja2
+
+if TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
+
+
+def build_jinja_env(
+    template_search_directories: Iterable[str],
+    config: "HomeServerConfig",
+    autoescape: Union[bool, Callable[[str], bool], None] = None,
+) -> jinja2.Environment:
+    """Set up a Jinja2 environment to load templates from the given search path
+
+    The returned environment defines the following filters:
+        - format_ts: formats timestamps as strings in the server's local timezone
+             (XXX: why is that useful??)
+        - mxc_to_http: converts mxc: uris to http URIs. Args are:
+             (uri, width, height, resize_method="crop")
+
+    and the following global variables:
+        - server_name: matrix server name
+
+    Args:
+        template_search_directories: directories to search for templates
+
+        config: homeserver config, for things like `server_name` and `public_baseurl`
+
+        autoescape: whether template variables should be autoescaped. bool, or
+           a function mapping from template name to bool. Defaults to escaping templates
+           whose names end in .html, .xml or .htm.
+
+    Returns:
+        jinja environment
+    """
+
+    if autoescape is None:
+        autoescape = jinja2.select_autoescape()
+
+    loader = jinja2.FileSystemLoader(template_search_directories)
+    env = jinja2.Environment(loader=loader, autoescape=autoescape)
+
+    # Update the environment with our custom filters
+    env.filters.update(
+        {
+            "format_ts": _format_ts_filter,
+            "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl),
+        }
+    )
+
+    # common variables for all templates
+    env.globals.update({"server_name": config.server_name})
+
+    return env
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+    """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
+
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index c37bb6440e..7baf224f7e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -62,7 +62,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -85,7 +85,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -94,7 +94,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -112,7 +112,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b3dfa40d25..d8f90b9a80 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -419,7 +419,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -450,7 +450,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -623,7 +623,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@foo:test", request, client_redirect_url, {"phone": "1234567"},
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
@@ -637,7 +641,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None,
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -648,7 +652,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None,
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -685,14 +689,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -707,7 +711,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -743,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None,
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -779,7 +783,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None,
+            "@test_user1:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 261c7083d1..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None
+            "@test_user1:test", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index e2bb945453..66dfdaffbc 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import create_requester
 
 from tests import unittest
@@ -423,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_get_login_flows(self):
@@ -1211,11 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_username_picker(self):
@@ -1229,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
         picker_url = channel.headers.getRawHeaders("Location")[0]
-        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
         cookies = {}  # type: Dict[str,str]
@@ -1253,12 +1246,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
         # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = picker_url + "/submit"
+        # to the completion page
         content = urlencode({b"username": b"bobby"}).encode("utf8")
         chan = self.make_request(
             "POST",
-            path=submit_path,
+            path=picker_url,
             content=content,
             content_is_form=True,
             custom_headers=[
@@ -1270,6 +1262,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)
         self.assertEqual(path, "https://x")
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index a6488a3d29..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -22,7 +22,7 @@ from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.rest.oidc import OIDCResource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def create_resource_dict(self):
         resource_dict = super().create_resource_dict()
-        if HAS_OIDC:
-            # mount the OIDC resource at /_synapse/oidc
-            resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
         return resource_dict
 
     def prepare(self, reactor, clock, hs):