summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/__init__.py8
-rw-r--r--synapse/rest/admin/registration_tokens.py321
-rw-r--r--synapse/rest/client/auth.py24
-rw-r--r--synapse/rest/client/register.py72
4 files changed, 425 insertions, 0 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7f3051aef1..6e1c8736e1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
 )
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.registration_tokens import (
+    ListRegistrationTokensRestServlet,
+    NewRegistrationTokenRestServlet,
+    RegistrationTokenRestServlet,
+)
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
     ForwardExtremitiesRestServlet,
@@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomEventContextServlet(hs).register(http_server)
     RateLimitRestServlet(hs).register(http_server)
     UsernameAvailableRestServlet(hs).register(http_server)
+    ListRegistrationTokensRestServlet(hs).register(http_server)
+    NewRegistrationTokenRestServlet(hs).register(http_server)
+    RegistrationTokenRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 0000000000..5a1c929d85
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    parse_boolean,
+    parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+    """List registration tokens.
+
+    To list all tokens:
+
+        GET /_synapse/admin/v1/registration_tokens
+
+        200 OK
+
+        {
+            "registration_tokens": [
+                {
+                    "token": "abcd",
+                    "uses_allowed": 3,
+                    "pending": 0,
+                    "completed": 1,
+                    "expiry_time": null
+                },
+                {
+                    "token": "wxyz",
+                    "uses_allowed": null,
+                    "pending": 0,
+                    "completed": 9,
+                    "expiry_time": 1625394937000
+                }
+            ]
+        }
+
+    The optional query parameter `valid` can be used to filter the response.
+    If it is `true`, only valid tokens are returned. If it is `false`, only
+    tokens that have expired or have had all uses exhausted are returned.
+    If it is omitted, all tokens are returned regardless of validity.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        valid = parse_boolean(request, "valid")
+        token_list = await self.store.get_registration_tokens(valid)
+        return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+    """Create a new registration token.
+
+    For example, to create a token specifying some fields:
+
+        POST /_synapse/admin/v1/registration_tokens/new
+
+        {
+            "token": "defg",
+            "uses_allowed": 1
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 1,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": null
+        }
+
+    Defaults are used for any fields not specified.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/new$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        # A string of all the characters allowed to be in a registration_token
+        self.allowed_chars = string.ascii_letters + string.digits + "-_"
+        self.allowed_chars_set = set(self.allowed_chars)
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+
+        if "token" in body:
+            token = body["token"]
+            if not isinstance(token, str):
+                raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+            if not (0 < len(token) <= 64):
+                raise SynapseError(
+                    400,
+                    "token must not be empty and must not be longer than 64 characters",
+                    Codes.INVALID_PARAM,
+                )
+            if not set(token).issubset(self.allowed_chars_set):
+                raise SynapseError(
+                    400,
+                    "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+                    Codes.INVALID_PARAM,
+                )
+
+        else:
+            # Get length of token to generate (default is 16)
+            length = body.get("length", 16)
+            if not isinstance(length, int):
+                raise SynapseError(
+                    400, "length must be an integer", Codes.INVALID_PARAM
+                )
+            if not (0 < length <= 64):
+                raise SynapseError(
+                    400,
+                    "length must be greater than zero and not greater than 64",
+                    Codes.INVALID_PARAM,
+                )
+
+            # Generate token
+            token = await self.store.generate_registration_token(
+                length, self.allowed_chars
+            )
+
+        uses_allowed = body.get("uses_allowed", None)
+        if not (
+            uses_allowed is None
+            or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+        ):
+            raise SynapseError(
+                400,
+                "uses_allowed must be a non-negative integer or null",
+                Codes.INVALID_PARAM,
+            )
+
+        expiry_time = body.get("expiry_time", None)
+        if not isinstance(expiry_time, (int, type(None))):
+            raise SynapseError(
+                400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+            )
+        if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+            raise SynapseError(
+                400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+            )
+
+        created = await self.store.create_registration_token(
+            token, uses_allowed, expiry_time
+        )
+        if not created:
+            raise SynapseError(
+                400, f"Token already exists: {token}", Codes.INVALID_PARAM
+            )
+
+        resp = {
+            "token": token,
+            "uses_allowed": uses_allowed,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": expiry_time,
+        }
+        return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+    """Retrieve, update, or delete the given token.
+
+    For example,
+
+    to retrieve a token:
+
+        GET /_synapse/admin/v1/registration_tokens/abcd
+
+        200 OK
+
+        {
+            "token": "abcd",
+            "uses_allowed": 3,
+            "pending": 0,
+            "completed": 1,
+            "expiry_time": null
+        }
+
+
+    to update a token:
+
+        PUT /_synapse/admin/v1/registration_tokens/defg
+
+        {
+            "uses_allowed": 5,
+            "expiry_time": 4781243146000
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 5,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": 4781243146000
+        }
+
+
+    to delete a token:
+
+        DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+        200 OK
+
+        {}
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Retrieve a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        token_info = await self.store.get_one_registration_token(token)
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Update a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+        new_attributes = {}
+
+        # Only add uses_allowed to new_attributes if it is present and valid
+        if "uses_allowed" in body:
+            uses_allowed = body["uses_allowed"]
+            if not (
+                uses_allowed is None
+                or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+            ):
+                raise SynapseError(
+                    400,
+                    "uses_allowed must be a non-negative integer or null",
+                    Codes.INVALID_PARAM,
+                )
+            new_attributes["uses_allowed"] = uses_allowed
+
+        if "expiry_time" in body:
+            expiry_time = body["expiry_time"]
+            if not isinstance(expiry_time, (int, type(None))):
+                raise SynapseError(
+                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                )
+            if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+                raise SynapseError(
+                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                )
+            new_attributes["expiry_time"] = expiry_time
+
+        if len(new_attributes) == 0:
+            # Nothing to update, get token info to return
+            token_info = await self.store.get_one_registration_token(token)
+        else:
+            token_info = await self.store.update_registration_token(
+                token, new_attributes
+            )
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_DELETE(
+        self, request: SynapseRequest, token: str
+    ) -> Tuple[int, JsonDict]:
+        """Delete a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+
+        if await self.store.delete_registration_token(token):
+            return 200, {}
+
+        raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
+        self.registration_token_template = hs.config.registration_token_template
         self.success_template = hs.config.fallback_success_template
 
     async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
             # re-authenticate with their SSO provider.
             html = await self.auth_handler.start_sso_ui_auth(request, session)
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            html = self.registration_token_template.render(
+                session=session,
+                myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+            )
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
@@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
             # The SSO fallback workflow should not post here,
             raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            token = parse_string(request, "token", required=True)
+            authdict = {"session": session, "token": token}
+
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                html = self.registration_token_template.render(
+                    session=session,
+                    myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+                    error=e.msg,
+                )
+            else:
+                html = self.success_template.render()
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
     UnrecognizedRequestError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
 from synapse.config.captcha import CaptchaConfig
 from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
             return 200, {"available": True}
 
 
+class RegistrationTokenValidityRestServlet(RestServlet):
+    """Check the validity of a registration token.
+
+    Example:
+
+        GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+        200 OK
+
+        {
+            "valid": true
+        }
+    """
+
+    PATTERNS = client_patterns(
+        f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+        releases=(),
+        unstable=True,
+    )
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.ratelimiter = Ratelimiter(
+            store=self.store,
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+            burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+        )
+
+    async def on_GET(self, request):
+        await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+        if not self.hs.config.enable_registration:
+            raise SynapseError(
+                403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+            )
+
+        token = parse_string(request, "token", required=True)
+        valid = await self.store.registration_token_is_valid(token)
+
+        return 200, {"valid": valid}
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_patterns("/register$")
 
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
         )
 
         if registered:
+            # Check if a token was used to authenticate registration
+            registration_token = await self.auth_handler.get_session_data(
+                session_id,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+            )
+            if registration_token:
+                # Increment the `completed` counter for the token
+                await self.store.use_registration_token(registration_token)
+                # Indicate that the token has been successfully used so that
+                # pending is not decremented again when expiring old UIA sessions.
+                await self.store.mark_ui_auth_stage_complete(
+                    session_id,
+                    LoginType.REGISTRATION_TOKEN,
+                    True,
+                )
+
             await self.registration_handler.post_registration_actions(
                 user_id=registered_user_id,
                 auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
         for flow in flows:
             flow.insert(0, LoginType.RECAPTCHA)
 
+    # Prepend registration token to all flows if we're requiring a token
+    if config.registration_requires_token:
+        for flow in flows:
+            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
     return flows
 
 
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     UsernameAvailabilityRestServlet(hs).register(http_server)
     RegistrationSubmitTokenServlet(hs).register(http_server)
+    RegistrationTokenValidityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)