summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorCallum Brown <callum@calcuode.com>2021-08-21 22:14:43 +0100
committerGitHub <noreply@github.com>2021-08-21 22:14:43 +0100
commit947dbbdfd1e0029da66f956d277b7c089928e1e7 (patch)
tree57cf53bcbb1f02e75e114cde5d0aa77662163038 /synapse
parentFlatten tests/rest/client/{v1,v2_alpha} too (#10667) (diff)
downloadsynapse-947dbbdfd1e0029da66f956d277b7c089928e1e7.tar.xz
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/app/generic_worker.py6
-rw-r--r--synapse/config/ratelimiting.py11
-rw-r--r--synapse/config/registration.py15
-rw-r--r--synapse/handlers/ui_auth/__init__.py5
-rw-r--r--synapse/handlers/ui_auth/checkers.py65
-rw-r--r--synapse/res/templates/registration_token.html23
-rw-r--r--synapse/rest/admin/__init__.py8
-rw-r--r--synapse/rest/admin/registration_tokens.py321
-rw-r--r--synapse/rest/client/auth.py24
-rw-r--r--synapse/rest/client/register.py72
-rw-r--r--synapse/storage/databases/main/registration.py316
-rw-r--r--synapse/storage/databases/main/ui_auth.py43
-rw-r--r--synapse/storage/schema/main/delta/63/01create_registration_tokens.sql23
14 files changed, 932 insertions, 1 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fddac..829061c870 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
     TERMS = "m.login.terms"
     SSO = "m.login.sso"
     DUMMY = "m.login.dummy"
+    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
 
 
 # This is used in the `type` parameter for /register when called by
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a8220..fd2626dbe1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
     ProfileRestServlet,
 )
 from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+    RegisterRestServlet,
+    RegistrationTokenValidityRestServlet,
+)
 from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.client.voip import VoipRestServlet
@@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
                     resource = JsonResource(self, canonical_json=False)
 
                     RegisterRestServlet(self).register(resource)
+                    RegistrationTokenValidityRestServlet(self).register(resource)
                     login.register_servlets(self, resource)
                     ThreepidRestServlet(self).register(resource)
                     DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 7a8d5851c4..f856327bd8 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -79,6 +79,11 @@ class RatelimitConfig(Config):
 
         self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
 
+        self.rc_registration_token_validity = RateLimitConfig(
+            config.get("rc_registration_token_validity", {}),
+            defaults={"per_second": 0.1, "burst_count": 5},
+        )
+
         rc_login_config = config.get("rc_login", {})
         self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
         self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@@ -143,6 +148,8 @@ class RatelimitConfig(Config):
         #     is using
         #   - one for registration that ratelimits registration requests based on the
         #     client's IP address.
+        #   - one for checking the validity of registration tokens that ratelimits
+        #     requests based on the client's IP address.
         #   - one for login that ratelimits login requests based on the client's IP
         #     address.
         #   - one for login that ratelimits login requests based on the account the
@@ -171,6 +178,10 @@ class RatelimitConfig(Config):
         #  per_second: 0.17
         #  burst_count: 3
         #
+        #rc_registration_token_validity:
+        #  per_second: 0.1
+        #  burst_count: 5
+        #
         #rc_login:
         #  address:
         #    per_second: 0.17
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0ad919b139..7cffdacfa5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,9 @@ class RegistrationConfig(Config):
         self.registrations_require_3pid = config.get("registrations_require_3pid", [])
         self.allowed_local_3pids = config.get("allowed_local_3pids", [])
         self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
+        self.registration_requires_token = config.get(
+            "registration_requires_token", False
+        )
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -140,6 +143,9 @@ class RegistrationConfig(Config):
                 "mechanism by removing the `access_token_lifetime` option."
             )
 
+        # The fallback template used for authenticating using a registration token
+        self.registration_token_template = self.read_template("registration_token.html")
+
         # The success template used during fallback auth.
         self.fallback_success_template = self.read_template("auth_success.html")
 
@@ -199,6 +205,15 @@ class RegistrationConfig(Config):
         #
         #enable_3pid_lookup: true
 
+        # Require users to submit a token during registration.
+        # Tokens can be managed using the admin API:
+        # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+        # Note that `enable_registration` must be set to `true`.
+        # Disabling this option will not delete any tokens previously generated.
+        # Defaults to false. Uncomment the following to require tokens:
+        #
+        #registration_requires_token: true
+
         # If set, allows registration of standard or admin accounts by anyone who
         # has the shared secret, even if registration is otherwise disabled.
         #
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
     # used by validate_user_via_ui_auth to store the mxid of the user we are validating
     # for.
     REQUEST_USER_ID = "request_user_id"
+
+    # used during registration to store the registration token used (if required) so that:
+    # - we can prevent a token being used twice by one session
+    # - we can 'use up' the token after registration has successfully completed
+    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 270541cc76..d3828dec6b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
         return await self._check_threepid("msisdn", authdict)
 
 
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+    AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self.hs = hs
+        self._enabled = bool(hs.config.registration_requires_token)
+        self.store = hs.get_datastore()
+
+    def is_enabled(self) -> bool:
+        return self._enabled
+
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
+        if "token" not in authdict:
+            raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+        if not isinstance(authdict["token"], str):
+            raise LoginError(
+                400, "Registration token must be a string", Codes.INVALID_PARAM
+            )
+        if "session" not in authdict:
+            raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+        # Get these here to avoid cyclic dependencies
+        from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+        auth_handler = self.hs.get_auth_handler()
+
+        session = authdict["session"]
+        token = authdict["token"]
+
+        # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+        # return early to avoid incrementing `pending` again.
+        stored_token = await auth_handler.get_session_data(
+            session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+        )
+        if stored_token:
+            if token != stored_token:
+                raise LoginError(
+                    400, "Registration token has changed", Codes.INVALID_PARAM
+                )
+            else:
+                return token
+
+        if await self.store.registration_token_is_valid(token):
+            # Increment pending counter, so that if token has limited uses it
+            # can't be used up by someone else in the meantime.
+            await self.store.set_registration_token_pending(token)
+            # Store the token in the UIA session, so that once registration
+            # is complete `completed` can be incremented.
+            await auth_handler.set_session_data(
+                session,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+                token,
+            )
+            # The token will be stored as the result of the authentication stage
+            # in ui_auth_sessions_credentials. This allows the pending counter
+            # for tokens to be decremented when expired sessions are deleted.
+            return token
+        else:
+            raise LoginError(
+                401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+            )
+
+
 INTERACTIVE_AUTH_CHECKERS = [
     DummyAuthChecker,
     TermsAuthChecker,
     RecaptchaAuthChecker,
     EmailIdentityAuthChecker,
     MsisdnAuthChecker,
+    RegistrationTokenAuthChecker,
 ]
 """A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
new file mode 100644
index 0000000000..4577ce1702
--- /dev/null
+++ b/synapse/res/templates/registration_token.html
@@ -0,0 +1,23 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+    <div>
+        {% if error is defined %}
+            <p class="error"><strong>Error: {{ error }}</strong></p>
+        {% endif %}
+        <p>
+           Please enter a registration token.
+        </p>
+        <input type="hidden" name="session" value="{{ session }}" />
+        <input type="text" name="token" />
+        <input type="submit" value="Authenticate" />
+    </div>
+</form>
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7f3051aef1..6e1c8736e1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
 )
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.registration_tokens import (
+    ListRegistrationTokensRestServlet,
+    NewRegistrationTokenRestServlet,
+    RegistrationTokenRestServlet,
+)
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
     ForwardExtremitiesRestServlet,
@@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomEventContextServlet(hs).register(http_server)
     RateLimitRestServlet(hs).register(http_server)
     UsernameAvailableRestServlet(hs).register(http_server)
+    ListRegistrationTokensRestServlet(hs).register(http_server)
+    NewRegistrationTokenRestServlet(hs).register(http_server)
+    RegistrationTokenRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 0000000000..5a1c929d85
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    parse_boolean,
+    parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+    """List registration tokens.
+
+    To list all tokens:
+
+        GET /_synapse/admin/v1/registration_tokens
+
+        200 OK
+
+        {
+            "registration_tokens": [
+                {
+                    "token": "abcd",
+                    "uses_allowed": 3,
+                    "pending": 0,
+                    "completed": 1,
+                    "expiry_time": null
+                },
+                {
+                    "token": "wxyz",
+                    "uses_allowed": null,
+                    "pending": 0,
+                    "completed": 9,
+                    "expiry_time": 1625394937000
+                }
+            ]
+        }
+
+    The optional query parameter `valid` can be used to filter the response.
+    If it is `true`, only valid tokens are returned. If it is `false`, only
+    tokens that have expired or have had all uses exhausted are returned.
+    If it is omitted, all tokens are returned regardless of validity.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        valid = parse_boolean(request, "valid")
+        token_list = await self.store.get_registration_tokens(valid)
+        return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+    """Create a new registration token.
+
+    For example, to create a token specifying some fields:
+
+        POST /_synapse/admin/v1/registration_tokens/new
+
+        {
+            "token": "defg",
+            "uses_allowed": 1
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 1,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": null
+        }
+
+    Defaults are used for any fields not specified.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/new$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        # A string of all the characters allowed to be in a registration_token
+        self.allowed_chars = string.ascii_letters + string.digits + "-_"
+        self.allowed_chars_set = set(self.allowed_chars)
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+
+        if "token" in body:
+            token = body["token"]
+            if not isinstance(token, str):
+                raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+            if not (0 < len(token) <= 64):
+                raise SynapseError(
+                    400,
+                    "token must not be empty and must not be longer than 64 characters",
+                    Codes.INVALID_PARAM,
+                )
+            if not set(token).issubset(self.allowed_chars_set):
+                raise SynapseError(
+                    400,
+                    "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+                    Codes.INVALID_PARAM,
+                )
+
+        else:
+            # Get length of token to generate (default is 16)
+            length = body.get("length", 16)
+            if not isinstance(length, int):
+                raise SynapseError(
+                    400, "length must be an integer", Codes.INVALID_PARAM
+                )
+            if not (0 < length <= 64):
+                raise SynapseError(
+                    400,
+                    "length must be greater than zero and not greater than 64",
+                    Codes.INVALID_PARAM,
+                )
+
+            # Generate token
+            token = await self.store.generate_registration_token(
+                length, self.allowed_chars
+            )
+
+        uses_allowed = body.get("uses_allowed", None)
+        if not (
+            uses_allowed is None
+            or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+        ):
+            raise SynapseError(
+                400,
+                "uses_allowed must be a non-negative integer or null",
+                Codes.INVALID_PARAM,
+            )
+
+        expiry_time = body.get("expiry_time", None)
+        if not isinstance(expiry_time, (int, type(None))):
+            raise SynapseError(
+                400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+            )
+        if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+            raise SynapseError(
+                400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+            )
+
+        created = await self.store.create_registration_token(
+            token, uses_allowed, expiry_time
+        )
+        if not created:
+            raise SynapseError(
+                400, f"Token already exists: {token}", Codes.INVALID_PARAM
+            )
+
+        resp = {
+            "token": token,
+            "uses_allowed": uses_allowed,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": expiry_time,
+        }
+        return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+    """Retrieve, update, or delete the given token.
+
+    For example,
+
+    to retrieve a token:
+
+        GET /_synapse/admin/v1/registration_tokens/abcd
+
+        200 OK
+
+        {
+            "token": "abcd",
+            "uses_allowed": 3,
+            "pending": 0,
+            "completed": 1,
+            "expiry_time": null
+        }
+
+
+    to update a token:
+
+        PUT /_synapse/admin/v1/registration_tokens/defg
+
+        {
+            "uses_allowed": 5,
+            "expiry_time": 4781243146000
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 5,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": 4781243146000
+        }
+
+
+    to delete a token:
+
+        DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+        200 OK
+
+        {}
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Retrieve a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        token_info = await self.store.get_one_registration_token(token)
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Update a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+        new_attributes = {}
+
+        # Only add uses_allowed to new_attributes if it is present and valid
+        if "uses_allowed" in body:
+            uses_allowed = body["uses_allowed"]
+            if not (
+                uses_allowed is None
+                or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+            ):
+                raise SynapseError(
+                    400,
+                    "uses_allowed must be a non-negative integer or null",
+                    Codes.INVALID_PARAM,
+                )
+            new_attributes["uses_allowed"] = uses_allowed
+
+        if "expiry_time" in body:
+            expiry_time = body["expiry_time"]
+            if not isinstance(expiry_time, (int, type(None))):
+                raise SynapseError(
+                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                )
+            if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+                raise SynapseError(
+                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                )
+            new_attributes["expiry_time"] = expiry_time
+
+        if len(new_attributes) == 0:
+            # Nothing to update, get token info to return
+            token_info = await self.store.get_one_registration_token(token)
+        else:
+            token_info = await self.store.update_registration_token(
+                token, new_attributes
+            )
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_DELETE(
+        self, request: SynapseRequest, token: str
+    ) -> Tuple[int, JsonDict]:
+        """Delete a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+
+        if await self.store.delete_registration_token(token):
+            return 200, {}
+
+        raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
+        self.registration_token_template = hs.config.registration_token_template
         self.success_template = hs.config.fallback_success_template
 
     async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
             # re-authenticate with their SSO provider.
             html = await self.auth_handler.start_sso_ui_auth(request, session)
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            html = self.registration_token_template.render(
+                session=session,
+                myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+            )
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
@@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
             # The SSO fallback workflow should not post here,
             raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            token = parse_string(request, "token", required=True)
+            authdict = {"session": session, "token": token}
+
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                html = self.registration_token_template.render(
+                    session=session,
+                    myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+                    error=e.msg,
+                )
+            else:
+                html = self.success_template.render()
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
     UnrecognizedRequestError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
 from synapse.config.captcha import CaptchaConfig
 from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
             return 200, {"available": True}
 
 
+class RegistrationTokenValidityRestServlet(RestServlet):
+    """Check the validity of a registration token.
+
+    Example:
+
+        GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+        200 OK
+
+        {
+            "valid": true
+        }
+    """
+
+    PATTERNS = client_patterns(
+        f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+        releases=(),
+        unstable=True,
+    )
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.ratelimiter = Ratelimiter(
+            store=self.store,
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+            burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+        )
+
+    async def on_GET(self, request):
+        await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+        if not self.hs.config.enable_registration:
+            raise SynapseError(
+                403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+            )
+
+        token = parse_string(request, "token", required=True)
+        valid = await self.store.registration_token_is_valid(token)
+
+        return 200, {"valid": valid}
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_patterns("/register$")
 
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
         )
 
         if registered:
+            # Check if a token was used to authenticate registration
+            registration_token = await self.auth_handler.get_session_data(
+                session_id,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+            )
+            if registration_token:
+                # Increment the `completed` counter for the token
+                await self.store.use_registration_token(registration_token)
+                # Indicate that the token has been successfully used so that
+                # pending is not decremented again when expiring old UIA sessions.
+                await self.store.mark_ui_auth_stage_complete(
+                    session_id,
+                    LoginType.REGISTRATION_TOKEN,
+                    True,
+                )
+
             await self.registration_handler.post_registration_actions(
                 user_id=registered_user_id,
                 auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
         for flow in flows:
             flow.insert(0, LoginType.RECAPTCHA)
 
+    # Prepend registration token to all flows if we're requiring a token
+    if config.registration_requires_token:
+        for flow in flows:
+            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
     return flows
 
 
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     UsernameAvailabilityRestServlet(hs).register(http_server)
     RegistrationSubmitTokenServlet(hs).register(http_server)
+    RegistrationTokenValidityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 469dd53e0c..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="update_access_token_last_validated",
         )
 
+    async def registration_token_is_valid(self, token: str) -> bool:
+        """Checks if a token can be used to authenticate a registration.
+
+        Args:
+            token: The registration token to be checked
+        Returns:
+            True if the token is valid, False otherwise.
+        """
+        res = await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+        )
+
+        # Check if the token exists
+        if res is None:
+            return False
+
+        # Check if the token has expired
+        now = self._clock.time_msec()
+        if res["expiry_time"] and res["expiry_time"] < now:
+            return False
+
+        # Check if the token has been used up
+        if (
+            res["uses_allowed"]
+            and res["pending"] + res["completed"] >= res["uses_allowed"]
+        ):
+            return False
+
+        # Otherwise, the token is valid
+        return True
+
+    async def set_registration_token_pending(self, token: str) -> None:
+        """Increment the pending registrations counter for a token.
+
+        Args:
+            token: The registration token pending use
+        """
+
+        def _set_registration_token_pending_txn(txn):
+            pending = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"pending": pending + 1},
+            )
+
+        return await self.db_pool.runInteraction(
+            "set_registration_token_pending", _set_registration_token_pending_txn
+        )
+
+    async def use_registration_token(self, token: str) -> None:
+        """Complete a use of the given registration token.
+
+        The `pending` counter will be decremented, and the `completed`
+        counter will be incremented.
+
+        Args:
+            token: The registration token to be 'used'
+        """
+
+        def _use_registration_token_txn(txn):
+            # Normally, res is Optional[Dict[str, Any]].
+            # Override type because the return type is only optional if
+            # allow_none is True, and we don't want mypy throwing errors
+            # about None not being indexable.
+            res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )  # type: ignore
+
+            # Decrement pending and increment completed
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={
+                    "completed": res["completed"] + 1,
+                    "pending": res["pending"] - 1,
+                },
+            )
+
+        return await self.db_pool.runInteraction(
+            "use_registration_token", _use_registration_token_txn
+        )
+
+    async def get_registration_tokens(
+        self, valid: Optional[bool] = None
+    ) -> List[Dict[str, Any]]:
+        """List all registration tokens. Used by the admin API.
+
+        Args:
+            valid: If True, only valid tokens are returned.
+              If False, only invalid tokens are returned.
+              Default is None: return all tokens regardless of validity.
+
+        Returns:
+            A list of dicts, each containing details of a token.
+        """
+
+        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+            if valid is None:
+                # Return all tokens regardless of validity
+                txn.execute("SELECT * FROM registration_tokens")
+
+            elif valid:
+                # Select valid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+                    "AND (expiry_time > ? OR expiry_time IS NULL)"
+                )
+                txn.execute(sql, [now])
+
+            else:
+                # Select invalid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "uses_allowed <= pending + completed OR expiry_time <= ?"
+                )
+                txn.execute(sql, [now])
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "select_registration_tokens",
+            select_registration_tokens_txn,
+            self._clock.time_msec(),
+            valid,
+        )
+
+    async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+        """Get info about the given registration token. Used by the admin API.
+
+        Args:
+            token: The token to retrieve information about.
+
+        Returns:
+            A dict, or None if token doesn't exist.
+        """
+        return await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+            desc="get_one_registration_token",
+        )
+
+    async def generate_registration_token(
+        self, length: int, chars: str
+    ) -> Optional[str]:
+        """Generate a random registration token. Used by the admin API.
+
+        Args:
+            length: The length of the token to generate.
+            chars: A string of the characters allowed in the generated token.
+
+        Returns:
+            The generated token.
+
+        Raises:
+            SynapseError if a unique registration token could still not be
+            generated after a few tries.
+        """
+        # Make a few attempts at generating a unique token of the required
+        # length before failing.
+        for _i in range(3):
+            # Generate token
+            token = "".join(random.choices(chars, k=length))
+
+            # Check if the token already exists
+            existing_token = await self.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="token",
+                allow_none=True,
+                desc="check_if_registration_token_exists",
+            )
+
+            if existing_token is None:
+                # The generated token doesn't exist yet, return it
+                return token
+
+        raise SynapseError(
+            500,
+            "Unable to generate a unique registration token. Try again with a greater length",
+            Codes.UNKNOWN,
+        )
+
+    async def create_registration_token(
+        self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+    ) -> bool:
+        """Create a new registration token. Used by the admin API.
+
+        Args:
+            token: The token to create.
+            uses_allowed: The number of times the token can be used to complete
+              a registration before it becomes invalid. A value of None indicates
+              unlimited uses.
+            expiry_time: The latest time the token is valid. Given as the
+              number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+              None indicates that the token does not expire.
+
+        Returns:
+            Whether the row was inserted or not.
+        """
+
+        def _create_registration_token_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["token"],
+                allow_none=True,
+            )
+
+            if row is not None:
+                # Token already exists
+                return False
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                "registration_tokens",
+                values={
+                    "token": token,
+                    "uses_allowed": uses_allowed,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": expiry_time,
+                },
+            )
+
+            return True
+
+        return await self.db_pool.runInteraction(
+            "create_registration_token", _create_registration_token_txn
+        )
+
+    async def update_registration_token(
+        self, token: str, updatevalues: Dict[str, Optional[int]]
+    ) -> Optional[Dict[str, Any]]:
+        """Update a registration token. Used by the admin API.
+
+        Args:
+            token: The token to update.
+            updatevalues: A dict with the fields to update. E.g.:
+              `{"uses_allowed": 3}` to update just uses_allowed, or
+              `{"uses_allowed": 3, "expiry_time": None}` to update both.
+              This is passed straight to simple_update_one.
+
+        Returns:
+            A dict with all info about the token, or None if token doesn't exist.
+        """
+
+        def _update_registration_token_txn(txn):
+            try:
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    "registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues=updatevalues,
+                )
+            except StoreError:
+                # Update failed because token does not exist
+                return None
+
+            # Get all info about the token so it can be sent in the response
+            return self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=[
+                    "token",
+                    "uses_allowed",
+                    "pending",
+                    "completed",
+                    "expiry_time",
+                ],
+                allow_none=True,
+            )
+
+        return await self.db_pool.runInteraction(
+            "update_registration_token", _update_registration_token_txn
+        )
+
+    async def delete_registration_token(self, token: str) -> bool:
+        """Delete a registration token. Used by the admin API.
+
+        Args:
+            token: The token to delete.
+
+        Returns:
+            Whether the token was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                desc="delete_registration_token",
+            )
+        except StoreError:
+            # Deletion failed because token does not exist
+            return False
+
+        return True
+
     @cached()
     async def mark_access_token_as_used(self, token_id: int) -> None:
         """
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
             keyvalues={},
         )
 
+        # If a registration token was used, decrement the pending counter
+        # before deleting the session.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="ui_auth_sessions_credentials",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+            retcols=["result"],
+        )
+
+        # Get the tokens used and how much pending needs to be decremented by.
+        token_counts: Dict[str, int] = {}
+        for r in rows:
+            # If registration was successfully completed, the result of the
+            # registration token stage for that session will be True.
+            # If a token was used to authenticate, but registration was
+            # never completed, the result will be the token used.
+            token = db_to_json(r["result"])
+            if isinstance(token, str):
+                token_counts[token] = token_counts.get(token, 0) + 1
+
+        # Update the `pending` counters.
+        if len(token_counts) > 0:
+            token_rows = self.db_pool.simple_select_many_txn(
+                txn,
+                table="registration_tokens",
+                column="token",
+                iterable=list(token_counts.keys()),
+                keyvalues={},
+                retcols=["token", "pending"],
+            )
+            for token_row in token_rows:
+                token = token_row["token"]
+                new_pending = token_row["pending"] - token_counts[token]
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    table="registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues={"pending": new_pending},
+                )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 0000000000..ee6cf958f4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+    token TEXT NOT NULL,  -- The token that can be used for authentication.
+    uses_allowed INT,  -- The total number of times this token can be used. NULL if no limit.
+    pending INT NOT NULL, -- The number of in progress registrations using this token.
+    completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+    expiry_time BIGINT,  -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+    UNIQUE (token)
+);