diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature
new file mode 100644
index 0000000000..5353f6269d
--- /dev/null
+++ b/changelog.d/10142.feature
@@ -0,0 +1 @@
+Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 634bc833ab..4fcd2b7852 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -53,6 +53,7 @@
- [Media](admin_api/media_admin_api.md)
- [Purge History](admin_api/purge_history_api.md)
- [Register Users](admin_api/register_api.md)
+ - [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
- [Manipulate Room Membership](admin_api/room_membership.md)
- [Rooms](admin_api/rooms.md)
- [Server Notices](admin_api/server_notices.md)
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 2b0c453242..935841dbfa 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
+# - one for checking the validity of registration tokens that ratelimits
+# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
+#rc_registration_token_validity:
+# per_second: 0.1
+# burst_count: 5
+#
#rc_login:
# address:
# per_second: 0.17
@@ -1169,6 +1175,15 @@ url_preview_accept_language:
#
#enable_3pid_lookup: true
+# Require users to submit a token during registration.
+# Tokens can be managed using the admin API:
+# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+# Note that `enable_registration` must be set to `true`.
+# Disabling this option will not delete any tokens previously generated.
+# Defaults to false. Uncomment the following to require tokens:
+#
+#registration_requires_token: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md
new file mode 100644
index 0000000000..828c0277d6
--- /dev/null
+++ b/docs/usage/administration/admin_api/registration_tokens.md
@@ -0,0 +1,295 @@
+# Registration Tokens
+
+This API allows you to manage tokens which can be used to authenticate
+registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
+To use it, you will need to enable the `registration_requires_token` config
+option, and authenticate by providing an `access_token` for a server admin:
+see [Admin API](../../usage/administration/admin_api).
+Note that this API is still experimental; not all clients may support it yet.
+
+
+## Registration token objects
+
+Most endpoints make use of JSON objects that contain details about tokens.
+These objects have the following fields:
+- `token`: The token which can be used to authenticate registration.
+- `uses_allowed`: The number of times the token can be used to complete a
+ registration before it becomes invalid.
+- `pending`: The number of pending uses the token has. When someone uses
+ the token to authenticate themselves, the pending counter is incremented
+ so that the token is not used more than the permitted number of times.
+ When the person completes registration the pending counter is decremented,
+ and the completed counter is incremented.
+- `completed`: The number of times the token has been used to successfully
+ complete a registration.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ To convert this into a human-readable form you can remove the milliseconds
+ and use the `date` command. For example, `date -d '@1625394937'`.
+
+
+## List all tokens
+
+Lists all tokens and details about them. If the request is successful, the top
+level JSON object will have a `registration_tokens` key which is an array of
+registration token objects.
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+
+Optional query parameters:
+- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
+ If `false`, only tokens that have expired or have had all uses exhausted are
+ returned. If omitted, all tokens are returned regardless of validity.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+```
+200 OK
+
+{
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "pqrs",
+ "uses_allowed": 2,
+ "pending": 1,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
+ }
+ ]
+}
+```
+
+Example using the `valid` query parameter:
+
+```
+GET /_synapse/admin/v1/registration_tokens?valid=false
+```
+```
+200 OK
+
+{
+ "registration_tokens": [
+ {
+ "token": "pqrs",
+ "uses_allowed": 2,
+ "pending": 1,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
+ }
+ ]
+}
+```
+
+
+## Get one token
+
+Get details about a single token. If the request is successful, the response
+body will be a registration token object.
+
+```
+GET /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to return details of.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens/abcd
+```
+```
+200 OK
+
+{
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+}
+```
+
+
+## Create token
+
+Create a new registration token. If the request is successful, the newly created
+token will be returned as a registration token object in the response body.
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+```
+
+The request body must be a JSON object and can contain the following fields:
+- `token`: The registration token. A string of no more than 64 characters that
+ consists only of characters matched by the regex `[A-Za-z0-9-_]`.
+ Default: randomly generated.
+- `uses_allowed`: The integer number of times the token can be used to complete
+ a registration before it becomes invalid.
+ Default: `null` (unlimited uses).
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ You could use, for example, `date '+%s000' -d 'tomorrow'`.
+ Default: `null` (token does not expire).
+- `length`: The length of the token randomly generated if `token` is not
+ specified. Must be between 1 and 64 inclusive. Default: `16`.
+
+If a field is omitted the default is used.
+
+Example using defaults:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{}
+```
+```
+200 OK
+
+{
+ "token": "0M-9jbkf2t_Tgiw1",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+}
+```
+
+Example specifying some fields:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{
+ "token": "defg",
+ "uses_allowed": 1
+}
+```
+```
+200 OK
+
+{
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+}
+```
+
+
+## Update token
+
+Update the number of allowed uses or expiry time of a token. If the request is
+successful, the updated token will be returned as a registration token object
+in the response body.
+
+```
+PUT /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to update.
+
+The request body must be a JSON object and can contain the following fields:
+- `uses_allowed`: The integer number of times the token can be used to complete
+ a registration before it becomes invalid. By setting `uses_allowed` to `0`
+ the token can be easily made invalid without deleting it.
+ If `null` the token will have an unlimited number of uses.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ If `null` the token will not expire.
+
+If a field is omitted its value is not modified.
+
+Example:
+
+```
+PUT /_synapse/admin/v1/registration_tokens/defg
+
+{
+ "expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC
+}
+```
+```
+200 OK
+
+{
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+}
+```
+
+
+## Delete token
+
+Delete a registration token. If the request is successful, the response body
+will be an empty JSON object.
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to delete.
+
+Example:
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/wxyz
+```
+```
+200 OK
+
+{}
+```
+
+
+## Errors
+
+If a request fails a "standard error response" will be returned as defined in
+the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
+
+For example, if the token specified in a path parameter does not exist a
+`404 Not Found` error will be returned.
+
+```
+GET /_synapse/admin/v1/registration_tokens/1234
+```
+```
+404 Not Found
+
+{
+ "errcode": "M_NOT_FOUND",
+ "error": "No such registration token: 1234"
+}
+```
diff --git a/docs/workers.md b/docs/workers.md
index 2e63f03452..3121241894 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -236,6 +236,7 @@ expressions:
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(r0|unstable)/register$
+ ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
# Event sending requests
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fddac..829061c870 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms"
SSO = "m.login.sso"
DUMMY = "m.login.dummy"
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a8220..fd2626dbe1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet,
)
from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+ RegisterRestServlet,
+ RegistrationTokenValidityRestServlet,
+)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet
@@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource)
+ RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 7a8d5851c4..f856327bd8 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_registration_token_validity = RateLimitConfig(
+ config.get("rc_registration_token_validity", {}),
+ defaults={"per_second": 0.1, "burst_count": 5},
+ )
+
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
+ # - one for checking the validity of registration tokens that ratelimits
+ # requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_registration_token_validity:
+ # per_second: 0.1
+ # burst_count: 5
+ #
#rc_login:
# address:
# per_second: 0.17
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0ad919b139..7cffdacfa5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
+ self.registration_requires_token = config.get(
+ "registration_requires_token", False
+ )
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option."
)
+ # The fallback template used for authenticating using a registration token
+ self.registration_token_template = self.read_template("registration_token.html")
+
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
@@ -199,6 +205,15 @@ class RegistrationConfig(Config):
#
#enable_3pid_lookup: true
+ # Require users to submit a token during registration.
+ # Tokens can be managed using the admin API:
+ # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+ # Note that `enable_registration` must be set to `true`.
+ # Disabling this option will not delete any tokens previously generated.
+ # Defaults to false. Uncomment the following to require tokens:
+ #
+ #registration_requires_token: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"
+
+ # used during registration to store the registration token used (if required) so that:
+ # - we can prevent a token being used twice by one session
+ # - we can 'use up' the token after registration has successfully completed
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 270541cc76..d3828dec6b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self._enabled = bool(hs.config.registration_requires_token)
+ self.store = hs.get_datastore()
+
+ def is_enabled(self) -> bool:
+ return self._enabled
+
+ async def check_auth(self, authdict: dict, clientip: str) -> Any:
+ if "token" not in authdict:
+ raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+ if not isinstance(authdict["token"], str):
+ raise LoginError(
+ 400, "Registration token must be a string", Codes.INVALID_PARAM
+ )
+ if "session" not in authdict:
+ raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+ # Get these here to avoid cyclic dependencies
+ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+ auth_handler = self.hs.get_auth_handler()
+
+ session = authdict["session"]
+ token = authdict["token"]
+
+ # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+ # return early to avoid incrementing `pending` again.
+ stored_token = await auth_handler.get_session_data(
+ session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+ )
+ if stored_token:
+ if token != stored_token:
+ raise LoginError(
+ 400, "Registration token has changed", Codes.INVALID_PARAM
+ )
+ else:
+ return token
+
+ if await self.store.registration_token_is_valid(token):
+ # Increment pending counter, so that if token has limited uses it
+ # can't be used up by someone else in the meantime.
+ await self.store.set_registration_token_pending(token)
+ # Store the token in the UIA session, so that once registration
+ # is complete `completed` can be incremented.
+ await auth_handler.set_session_data(
+ session,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ token,
+ )
+ # The token will be stored as the result of the authentication stage
+ # in ui_auth_sessions_credentials. This allows the pending counter
+ # for tokens to be decremented when expired sessions are deleted.
+ return token
+ else:
+ raise LoginError(
+ 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+ )
+
+
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
+ RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
new file mode 100644
index 0000000000..4577ce1702
--- /dev/null
+++ b/synapse/res/templates/registration_token.html
@@ -0,0 +1,23 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+ <div>
+ {% if error is defined %}
+ <p class="error"><strong>Error: {{ error }}</strong></p>
+ {% endif %}
+ <p>
+ Please enter a registration token.
+ </p>
+ <input type="hidden" name="session" value="{{ session }}" />
+ <input type="text" name="token" />
+ <input type="submit" value="Authenticate" />
+ </div>
+</form>
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7f3051aef1..6e1c8736e1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.registration_tokens import (
+ ListRegistrationTokensRestServlet,
+ NewRegistrationTokenRestServlet,
+ RegistrationTokenRestServlet,
+)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
@@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
+ ListRegistrationTokensRestServlet(hs).register(http_server)
+ NewRegistrationTokenRestServlet(hs).register(http_server)
+ RegistrationTokenRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 0000000000..5a1c929d85
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+ """List registration tokens.
+
+ To list all tokens:
+
+ GET /_synapse/admin/v1/registration_tokens
+
+ 200 OK
+
+ {
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000
+ }
+ ]
+ }
+
+ The optional query parameter `valid` can be used to filter the response.
+ If it is `true`, only valid tokens are returned. If it is `false`, only
+ tokens that have expired or have had all uses exhausted are returned.
+ If it is omitted, all tokens are returned regardless of validity.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ valid = parse_boolean(request, "valid")
+ token_list = await self.store.get_registration_tokens(valid)
+ return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+ """Create a new registration token.
+
+ For example, to create a token specifying some fields:
+
+ POST /_synapse/admin/v1/registration_tokens/new
+
+ {
+ "token": "defg",
+ "uses_allowed": 1
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+ }
+
+ Defaults are used for any fields not specified.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/new$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ # A string of all the characters allowed to be in a registration_token
+ self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars_set = set(self.allowed_chars)
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+
+ if "token" in body:
+ token = body["token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+ if not (0 < len(token) <= 64):
+ raise SynapseError(
+ 400,
+ "token must not be empty and must not be longer than 64 characters",
+ Codes.INVALID_PARAM,
+ )
+ if not set(token).issubset(self.allowed_chars_set):
+ raise SynapseError(
+ 400,
+ "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+ Codes.INVALID_PARAM,
+ )
+
+ else:
+ # Get length of token to generate (default is 16)
+ length = body.get("length", 16)
+ if not isinstance(length, int):
+ raise SynapseError(
+ 400, "length must be an integer", Codes.INVALID_PARAM
+ )
+ if not (0 < length <= 64):
+ raise SynapseError(
+ 400,
+ "length must be greater than zero and not greater than 64",
+ Codes.INVALID_PARAM,
+ )
+
+ # Generate token
+ token = await self.store.generate_registration_token(
+ length, self.allowed_chars
+ )
+
+ uses_allowed = body.get("uses_allowed", None)
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+
+ expiry_time = body.get("expiry_time", None)
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+
+ created = await self.store.create_registration_token(
+ token, uses_allowed, expiry_time
+ )
+ if not created:
+ raise SynapseError(
+ 400, f"Token already exists: {token}", Codes.INVALID_PARAM
+ )
+
+ resp = {
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ }
+ return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+ """Retrieve, update, or delete the given token.
+
+ For example,
+
+ to retrieve a token:
+
+ GET /_synapse/admin/v1/registration_tokens/abcd
+
+ 200 OK
+
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ }
+
+
+ to update a token:
+
+ PUT /_synapse/admin/v1/registration_tokens/defg
+
+ {
+ "uses_allowed": 5,
+ "expiry_time": 4781243146000
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 5,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+ }
+
+
+ to delete a token:
+
+ DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+ 200 OK
+
+ {}
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Retrieve a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ token_info = await self.store.get_one_registration_token(token)
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Update a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+ new_attributes = {}
+
+ # Only add uses_allowed to new_attributes if it is present and valid
+ if "uses_allowed" in body:
+ uses_allowed = body["uses_allowed"]
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+ new_attributes["uses_allowed"] = uses_allowed
+
+ if "expiry_time" in body:
+ expiry_time = body["expiry_time"]
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+ new_attributes["expiry_time"] = expiry_time
+
+ if len(new_attributes) == 0:
+ # Nothing to update, get token info to return
+ token_info = await self.store.get_one_registration_token(token)
+ else:
+ token_info = await self.store.update_registration_token(
+ token, new_attributes
+ )
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_DELETE(
+ self, request: SynapseRequest, token: str
+ ) -> Tuple[int, JsonDict]:
+ """Delete a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+
+ if await self.store.delete_registration_token(token):
+ return 200, {}
+
+ raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True}
+class RegistrationTokenValidityRestServlet(RestServlet):
+ """Check the validity of a registration token.
+
+ Example:
+
+ GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+ 200 OK
+
+ {
+ "valid": true
+ }
+ """
+
+ PATTERNS = client_patterns(
+ f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+ releases=(),
+ unstable=True,
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.ratelimiter = Ratelimiter(
+ store=self.store,
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+ burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+ )
+
+ async def on_GET(self, request):
+ await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
+ token = parse_string(request, "token", required=True)
+ valid = await self.store.registration_token_is_valid(token)
+
+ return 200, {"valid": valid}
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
)
if registered:
+ # Check if a token was used to authenticate registration
+ registration_token = await self.auth_handler.get_session_data(
+ session_id,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ )
+ if registration_token:
+ # Increment the `completed` counter for the token
+ await self.store.use_registration_token(registration_token)
+ # Indicate that the token has been successfully used so that
+ # pending is not decremented again when expiring old UIA sessions.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id,
+ LoginType.REGISTRATION_TOKEN,
+ True,
+ )
+
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
+ # Prepend registration token to all flows if we're requiring a token
+ if config.registration_requires_token:
+ for flow in flows:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
return flows
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
+ RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 469dd53e0c..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ async def registration_token_is_valid(self, token: str) -> bool:
+ """Checks if a token can be used to authenticate a registration.
+
+ Args:
+ token: The registration token to be checked
+ Returns:
+ True if the token is valid, False otherwise.
+ """
+ res = await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ )
+
+ # Check if the token exists
+ if res is None:
+ return False
+
+ # Check if the token has expired
+ now = self._clock.time_msec()
+ if res["expiry_time"] and res["expiry_time"] < now:
+ return False
+
+ # Check if the token has been used up
+ if (
+ res["uses_allowed"]
+ and res["pending"] + res["completed"] >= res["uses_allowed"]
+ ):
+ return False
+
+ # Otherwise, the token is valid
+ return True
+
+ async def set_registration_token_pending(self, token: str) -> None:
+ """Increment the pending registrations counter for a token.
+
+ Args:
+ token: The registration token pending use
+ """
+
+ def _set_registration_token_pending_txn(txn):
+ pending = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": pending + 1},
+ )
+
+ return await self.db_pool.runInteraction(
+ "set_registration_token_pending", _set_registration_token_pending_txn
+ )
+
+ async def use_registration_token(self, token: str) -> None:
+ """Complete a use of the given registration token.
+
+ The `pending` counter will be decremented, and the `completed`
+ counter will be incremented.
+
+ Args:
+ token: The registration token to be 'used'
+ """
+
+ def _use_registration_token_txn(txn):
+ # Normally, res is Optional[Dict[str, Any]].
+ # Override type because the return type is only optional if
+ # allow_none is True, and we don't want mypy throwing errors
+ # about None not being indexable.
+ res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ) # type: ignore
+
+ # Decrement pending and increment completed
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={
+ "completed": res["completed"] + 1,
+ "pending": res["pending"] - 1,
+ },
+ )
+
+ return await self.db_pool.runInteraction(
+ "use_registration_token", _use_registration_token_txn
+ )
+
+ async def get_registration_tokens(
+ self, valid: Optional[bool] = None
+ ) -> List[Dict[str, Any]]:
+ """List all registration tokens. Used by the admin API.
+
+ Args:
+ valid: If True, only valid tokens are returned.
+ If False, only invalid tokens are returned.
+ Default is None: return all tokens regardless of validity.
+
+ Returns:
+ A list of dicts, each containing details of a token.
+ """
+
+ def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ if valid is None:
+ # Return all tokens regardless of validity
+ txn.execute("SELECT * FROM registration_tokens")
+
+ elif valid:
+ # Select valid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+ "AND (expiry_time > ? OR expiry_time IS NULL)"
+ )
+ txn.execute(sql, [now])
+
+ else:
+ # Select invalid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "uses_allowed <= pending + completed OR expiry_time <= ?"
+ )
+ txn.execute(sql, [now])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "select_registration_tokens",
+ select_registration_tokens_txn,
+ self._clock.time_msec(),
+ valid,
+ )
+
+ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+ """Get info about the given registration token. Used by the admin API.
+
+ Args:
+ token: The token to retrieve information about.
+
+ Returns:
+ A dict, or None if token doesn't exist.
+ """
+ return await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ desc="get_one_registration_token",
+ )
+
+ async def generate_registration_token(
+ self, length: int, chars: str
+ ) -> Optional[str]:
+ """Generate a random registration token. Used by the admin API.
+
+ Args:
+ length: The length of the token to generate.
+ chars: A string of the characters allowed in the generated token.
+
+ Returns:
+ The generated token.
+
+ Raises:
+ SynapseError if a unique registration token could still not be
+ generated after a few tries.
+ """
+ # Make a few attempts at generating a unique token of the required
+ # length before failing.
+ for _i in range(3):
+ # Generate token
+ token = "".join(random.choices(chars, k=length))
+
+ # Check if the token already exists
+ existing_token = await self.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="token",
+ allow_none=True,
+ desc="check_if_registration_token_exists",
+ )
+
+ if existing_token is None:
+ # The generated token doesn't exist yet, return it
+ return token
+
+ raise SynapseError(
+ 500,
+ "Unable to generate a unique registration token. Try again with a greater length",
+ Codes.UNKNOWN,
+ )
+
+ async def create_registration_token(
+ self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+ ) -> bool:
+ """Create a new registration token. Used by the admin API.
+
+ Args:
+ token: The token to create.
+ uses_allowed: The number of times the token can be used to complete
+ a registration before it becomes invalid. A value of None indicates
+ unlimited uses.
+ expiry_time: The latest time the token is valid. Given as the
+ number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+ None indicates that the token does not expire.
+
+ Returns:
+ Whether the row was inserted or not.
+ """
+
+ def _create_registration_token_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token"],
+ allow_none=True,
+ )
+
+ if row is not None:
+ # Token already exists
+ return False
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ "registration_tokens",
+ values={
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ },
+ )
+
+ return True
+
+ return await self.db_pool.runInteraction(
+ "create_registration_token", _create_registration_token_txn
+ )
+
+ async def update_registration_token(
+ self, token: str, updatevalues: Dict[str, Optional[int]]
+ ) -> Optional[Dict[str, Any]]:
+ """Update a registration token. Used by the admin API.
+
+ Args:
+ token: The token to update.
+ updatevalues: A dict with the fields to update. E.g.:
+ `{"uses_allowed": 3}` to update just uses_allowed, or
+ `{"uses_allowed": 3, "expiry_time": None}` to update both.
+ This is passed straight to simple_update_one.
+
+ Returns:
+ A dict with all info about the token, or None if token doesn't exist.
+ """
+
+ def _update_registration_token_txn(txn):
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues=updatevalues,
+ )
+ except StoreError:
+ # Update failed because token does not exist
+ return None
+
+ # Get all info about the token so it can be sent in the response
+ return self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=[
+ "token",
+ "uses_allowed",
+ "pending",
+ "completed",
+ "expiry_time",
+ ],
+ allow_none=True,
+ )
+
+ return await self.db_pool.runInteraction(
+ "update_registration_token", _update_registration_token_txn
+ )
+
+ async def delete_registration_token(self, token: str) -> bool:
+ """Delete a registration token. Used by the admin API.
+
+ Args:
+ token: The token to delete.
+
+ Returns:
+ Whether the token was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ desc="delete_registration_token",
+ )
+ except StoreError:
+ # Deletion failed because token does not exist
+ return False
+
+ return True
+
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 0000000000..ee6cf958f4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+ token TEXT NOT NULL, -- The token that can be used for authentication.
+ uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
+ pending INT NOT NULL, -- The number of in progress registrations using this token.
+ completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+ expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+ UNIQUE (token)
+);
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
|