summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-12-18 14:19:46 +0000
committerGitHub <noreply@github.com>2020-12-18 14:19:46 +0000
commit28877fade90a5cfb3457c9e6c70924dbbe8af715 (patch)
tree77e5c2e21144fb58026b98a6ad340b95a7419bb4 /synapse
parentAllow re-using a UI auth validation for a period of time (#8970) (diff)
downloadsynapse-28877fade90a5cfb3457c9e6c70924dbbe8af715.tar.xz
Implement a username picker for synapse (#8942)
The final part (for now) of my work to implement a username picker in synapse itself. The idea is that we allow
`UsernameMappingProvider`s to return `localpart=None`, in which case, rather than redirecting the browser
back to the client, we redirect to a username-picker resource, which allows the user to enter a username.
We *then* complete the SSO flow (including doing the client permission checks).

The static resources for the username picker itself (in 
https://github.com/matrix-org/synapse/tree/rav/username_picker/synapse/res/username_picker)
are essentially lifted wholesale from
https://github.com/matrix-org/matrix-synapse-saml-mozilla/tree/master/matrix_synapse_saml_mozilla/res. 
As the comment says, we might want to think about making them customisable, but that can be a follow-up. 

Fixes #8876.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/homeserver.py2
-rw-r--r--synapse/config/oidc_config.py5
-rw-r--r--synapse/handlers/oidc_handler.py59
-rw-r--r--synapse/handlers/sso.py254
-rw-r--r--synapse/res/username_picker/index.html19
-rw-r--r--synapse/res/username_picker/script.js95
-rw-r--r--synapse/res/username_picker/style.css27
-rw-r--r--synapse/rest/synapse/client/pick_username.py88
-rw-r--r--synapse/types.py8
9 files changed, 512 insertions, 45 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index bbb7407838..8d9b53be53 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer):
                     "/_matrix/client/versions": client_resource,
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
+                    "/_synapse/client/pick_username": pick_username_resource(self),
                 }
             )
 
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 1abf8ed405..4e3055282d 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -203,9 +203,10 @@ class OIDCConfig(Config):
               #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
               #     Token
               #
-              # This must be configured if using the default mapping provider.
+              # If this is not set, the user will be prompted to choose their
+              # own username.
               #
-              localpart_template: "{{{{ user.preferred_username }}}}"
+              #localpart_template: "{{{{ user.preferred_username }}}}"
 
               # Jinja2 template for the display name to set on first login.
               #
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index cbd11a1382..709f8dfc13 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -947,7 +947,7 @@ class OidcHandler(BaseHandler):
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)
 
 @attr.s
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
+        localpart_template = None  # type: Optional[Template]
+        if "localpart_template" in config:
+            try:
+                localpart_template = env.from_string(config["localpart_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["localpart_template"]
+                ) from e
 
         display_name_template = None  # type: Optional[Template]
         if "display_name_template" in config:
@@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name_template = env.from_string(config["display_name_template"])
             except Exception as e:
                 raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                    "invalid jinja template", path=["display_name_template"]
+                ) from e
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])
 
             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
@@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
+
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()
 
-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)
 
-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""
 
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f054b66a53..548b02211b 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 
 import attr
+from typing_extensions import NoReturn
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException
+from synapse.api.errors import RedirectException, SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -40,16 +42,52 @@ class MappingException(Exception):
 
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
     emails = attr.ib(type=List[str], default=attr.Factory(list))
 
 
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
+
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=List[str])
+
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
 class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
@@ -59,6 +97,9 @@ class SsoHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
 
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
@@ -206,6 +247,18 @@ class SsoHandler:
             # Otherwise, generate a new user.
             if not user_id:
                 attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
                 user_id = await self._register_mapped_user(
                     attributes,
                     auth_provider_id,
@@ -243,10 +296,8 @@ class SsoHandler:
             )
 
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
 
             # Check if this mxid already exists
             user_id = UserID(attributes.localpart, self._server_name).to_string()
@@ -261,6 +312,59 @@ class SsoHandler:
             )
         return attributes
 
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
     async def _register_mapped_user(
         self,
         attributes: UserAttributes,
@@ -269,9 +373,38 @@ class SsoHandler:
         user_agent: str,
         ip_address: str,
     ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+             attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
+
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
 
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -326,3 +459,108 @@ class SsoHandler:
         await self._auth_handler.complete_sso_ui_auth(
             user_id, ui_auth_session_id, request
         )
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self, request: SynapseRequest, localpart: str, session_id: str
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+        attributes = UserAttributes(
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
+        )
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            request.get_user_agent(""),
+            request.getClientIP(),
+        )
+
+        logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
new file mode 100644
index 0000000000..37ea8bb6d8
--- /dev/null
+++ b/synapse/res/username_picker/index.html
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <title>Synapse Login</title>
+    <link rel="stylesheet" href="style.css" type="text/css" />
+  </head>
+  <body>
+    <div class="card">
+      <form method="post" class="form__input" id="form" action="submit">
+        <label for="field-username">Please pick your username:</label>
+        <input type="text" name="username" id="field-username" autofocus="">
+        <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+      </form>
+      <!-- this is used for feedback -->
+      <div role=alert class="tooltip hidden" id="message"></div>
+      <script src="script.js"></script>
+    </div>
+  </body>
+</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
new file mode 100644
index 0000000000..416a7c6f41
--- /dev/null
+++ b/synapse/res/username_picker/script.js
@@ -0,0 +1,95 @@
+let inputField = document.getElementById("field-username");
+let inputForm = document.getElementById("form");
+let submitButton = document.getElementById("button-submit");
+let message = document.getElementById("message");
+
+// Submit username and receive response
+function showMessage(messageText) {
+    // Unhide the message text
+    message.classList.remove("hidden");
+
+    message.textContent = messageText;
+};
+
+function doSubmit() {
+    showMessage("Success. Please wait a moment for your browser to redirect.");
+
+    // remove the event handler before re-submitting the form.
+    delete inputForm.onsubmit;
+    inputForm.submit();
+}
+
+function onResponse(response) {
+    // Display message
+    showMessage(response);
+
+    // Enable submit button and input field
+    submitButton.classList.remove('button--disabled');
+    submitButton.value = "Submit";
+};
+
+let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
+function usernameIsValid(username) {
+    return !allowedUsernameCharacters.test(username);
+}
+let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function buildQueryString(params) {
+    return Object.keys(params)
+        .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
+        .join('&');
+}
+
+function submitUsername(username) {
+    if(username.length == 0) {
+        onResponse("Please enter a username.");
+        return;
+    }
+    if(!usernameIsValid(username)) {
+        onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
+        return;
+    }
+
+    // if this browser doesn't support fetch, skip the availability check.
+    if(!window.fetch) {
+        doSubmit();
+        return;
+    }
+
+    let check_uri = 'check?' + buildQueryString({"username": username});
+    fetch(check_uri, {
+        // include the cookie
+        "credentials": "same-origin",
+    }).then((response) => {
+        if(!response.ok) {
+            // for non-200 responses, raise the body of the response as an exception
+            return response.text().then((text) => { throw text; });
+        } else {
+            return response.json();
+        }
+    }).then((json) => {
+        if(json.error) {
+            throw json.error;
+        } else if(json.available) {
+            doSubmit();
+        } else {
+            onResponse("This username is not available, please choose another.");
+        }
+    }).catch((err) => {
+        onResponse("Error checking username availability: " + err);
+    });
+}
+
+function clickSubmit() {
+    event.preventDefault();
+    if(submitButton.classList.contains('button--disabled')) { return; }
+
+    // Disable submit button and input field
+    submitButton.classList.add('button--disabled');
+
+    // Submit username
+    submitButton.value = "Checking...";
+    submitUsername(inputField.value);
+};
+
+inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
new file mode 100644
index 0000000000..745bd4c684
--- /dev/null
+++ b/synapse/res/username_picker/style.css
@@ -0,0 +1,27 @@
+input[type="text"] {
+  font-size: 100%;
+  background-color: #ededf0;
+  border: 1px solid #fff;
+  border-radius: .2em;
+  padding: .5em .9em;
+  display: block;
+  width: 26em;
+}
+
+.button--disabled {
+  border-color: #fff;
+  background-color: transparent;
+  color: #000;
+  text-transform: none;
+}
+
+.hidden {
+  display: none;
+}
+
+.tooltip {
+  background-color: #f9f9fa;
+  padding: 1em;
+  margin: 1em 0;
+}
+
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..d3b6803e65
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import pkg_resources
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+from twisted.web.static import File
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
+from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+    """Factory method to generate the username picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_username. The top-level
+    resource is just a File resource which serves up the static files in the resources
+    "res" directory, but it has a couple of children:
+
+    * "submit", which does the mechanics of registering the new user, and redirects the
+      browser back to the client URL
+
+    * "check": checks if a userid is free.
+    """
+
+    # XXX should we make this path customisable so that admins can restyle it?
+    base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
+
+    res = File(base_path)
+    res.putChild(b"submit", SubmitResource(hs))
+    res.putChild(b"check", AvailabilityCheckResource(hs))
+
+    return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request):
+        localpart = parse_string(request, "username", required=True)
+
+        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+        if not session_id:
+            raise SynapseError(code=400, msg="missing session_id")
+
+        is_available = await self._sso_handler.check_username_availability(
+            localpart, session_id.decode("ascii", errors="replace")
+        )
+        return 200, {"available": is_available}
+
+
+class SubmitResource(DirectServeHtmlResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_POST(self, request: SynapseRequest):
+        localpart = parse_string(request, "username", required=True)
+
+        session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+        if not session_id:
+            raise SynapseError(code=400, msg="missing session_id")
+
+        await self._sso_handler.handle_submit_username_request(
+            request, localpart, session_id.decode("ascii", errors="replace")
+        )
diff --git a/synapse/types.py b/synapse/types.py
index 3ab6bdbe06..c7d4e95809 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile(
 )
 
 
-def map_username_to_mxid_localpart(username, case_sensitive=False):
+def map_username_to_mxid_localpart(
+    username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
     """Map a username onto a string suitable for a MXID
 
     This follows the algorithm laid out at
     https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
 
     Args:
-        username (unicode|bytes): username to be mapped
-        case_sensitive (bool): true if TEST and test should be mapped
+        username: username to be mapped
+        case_sensitive: true if TEST and test should be mapped
             onto different mxids
 
     Returns: