summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-09-29 14:23:24 +0100
committerGitHub <noreply@github.com>2022-09-29 15:23:24 +0200
commitbe76cd8200b18f3c68b895f85ac7ef5b0ddc2466 (patch)
treec32c1b1e7a835b9f970fc4ae28ded0dd858ff841 /synapse/rest
parentExplicit cast to enforce type hints. (#13939) (diff)
downloadsynapse-be76cd8200b18f3c68b895f85ac7ef5b0ddc2466.tar.xz
Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556)
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/users.py43
-rw-r--r--synapse/rest/client/login.py37
-rw-r--r--synapse/rest/client/register.py22
3 files changed, 96 insertions, 6 deletions
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 1274773d7e..15ac2059aa 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
+        self._msc3866_enabled = hs.config.experimental.msc3866.enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
@@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
+        # If support for MSC3866 is not enabled, apply no filtering based on the
+        # `approved` column.
+        if self._msc3866_enabled:
+            approved = parse_boolean(request, "approved", default=True)
+        else:
+            approved = True
+
         order_by = parse_string(
             request,
             "order_by",
@@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
         direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
 
         users, total = await self.store.get_users_paginate(
-            start, limit, user_id, name, guests, deactivated, order_by, direction
+            start,
+            limit,
+            user_id,
+            name,
+            guests,
+            deactivated,
+            order_by,
+            direction,
+            approved,
         )
+
+        # If support for MSC3866 is not enabled, don't show the approval flag.
+        if not self._msc3866_enabled:
+            for user in users:
+                del user["approved"]
+
         ret = {"users": users, "total": total}
         if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
@@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
         self.deactivate_account_handler = hs.get_deactivate_account_handler()
         self.registration_handler = hs.get_registration_handler()
         self.pusher_pool = hs.get_pusherpool()
+        self._msc3866_enabled = hs.config.experimental.msc3866.enabled
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
                 HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
             )
 
+        approved: Optional[bool] = None
+        if "approved" in body and self._msc3866_enabled:
+            approved = body["approved"]
+            if not isinstance(approved, bool):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "'approved' parameter is not of type boolean",
+                )
+
         # convert List[Dict[str, str]] into List[Tuple[str, str]]
         if external_ids is not None:
             new_external_ids = [
@@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
             if "user_type" in body:
                 await self.store.set_user_type(target_user, user_type)
 
+            if approved is not None:
+                await self.store.update_user_approval_status(target_user, approved)
+
             user = await self.admin_handler.get_user(target_user)
             assert user is not None
 
@@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
             if password is not None:
                 password_hash = await self.auth_handler.hash(password)
 
+            new_user_approved = True
+            if self._msc3866_enabled and approved is not None:
+                new_user_approved = approved
+
             user_id = await self.registration_handler.register_user(
                 localpart=target_user.localpart,
                 password_hash=password_hash,
@@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
                 default_display_name=displayname,
                 user_type=user_type,
                 by_admin=True,
+                approved=new_user_approved,
             )
 
             if threepids is not None:
@@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
             user_type=user_type,
             default_display_name=displayname,
             by_admin=True,
+            approved=True,
         )
 
         result = await register._create_registration_details(user_id, body)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0437c87d8d..f554586ac3 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,14 @@ from typing import (
 
 from typing_extensions import TypedDict
 
-from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
+from synapse.api.constants import ApprovalNoticeMedium
+from synapse.api.errors import (
+    Codes,
+    InvalidClientTokenError,
+    LoginError,
+    NotApprovedError,
+    SynapseError,
+)
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
@@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
 
 class LoginResponse(TypedDict, total=False):
     user_id: str
-    access_token: str
+    access_token: Optional[str]
     home_server: str
     expires_in_ms: Optional[int]
     refresh_token: Optional[str]
-    device_id: str
+    device_id: Optional[str]
     well_known: Optional[Dict[str, Any]]
 
 
@@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
 
+        # Whether we need to check if the user has been approved or not.
+        self._require_approval = (
+            hs.config.experimental.msc3866.enabled
+            and hs.config.experimental.msc3866.require_approval_for_new_accounts
+        )
+
         self.auth = hs.get_auth()
 
         self.clock = hs.get_clock()
@@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
 
+        if self._require_approval:
+            approved = await self.auth_handler.is_user_approved(result["user_id"])
+            if not approved:
+                raise NotApprovedError(
+                    msg="This account is pending approval by a server administrator.",
+                    approval_notice_medium=ApprovalNoticeMedium.NONE,
+                )
+
         well_known_data = self._well_known_builder.get_well_known()
         if well_known_data:
             result["well_known"] = well_known_data
@@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
+        if self._require_approval:
+            approved = await self.auth_handler.is_user_approved(user_id)
+            if not approved:
+                # If the user isn't approved (and needs to be) we won't allow them to
+                # actually log in, so we don't want to create a device/access token.
+                return LoginResponse(
+                    user_id=user_id,
+                    home_server=self.hs.hostname,
+                )
+
         initial_display_name = login_submission.get("initial_device_display_name")
         (
             device_id,
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 20bab20c8f..de810ae3ec 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -21,10 +21,15 @@ from twisted.web.server import Request
 import synapse
 import synapse.api.auth
 import synapse.types
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+    APP_SERVICE_REGISTRATION_TYPE,
+    ApprovalNoticeMedium,
+    LoginType,
+)
 from synapse.api.errors import (
     Codes,
     InteractiveAuthIncompleteError,
+    NotApprovedError,
     SynapseError,
     ThreepidValidationError,
     UnrecognizedRequestError,
@@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet):
             hs.config.registration.inhibit_user_in_use_error
         )
 
+        self._require_approval = (
+            hs.config.experimental.msc3866.enabled
+            and hs.config.experimental.msc3866.require_approval_for_new_accounts
+        )
+
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
         )
@@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet):
                 access_token=return_dict.get("access_token"),
             )
 
+            if self._require_approval:
+                raise NotApprovedError(
+                    msg="This account needs to be approved by an administrator before it can be used.",
+                    approval_notice_medium=ApprovalNoticeMedium.NONE,
+                )
+
         return 200, return_dict
 
     async def _do_appservice_registration(
@@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet):
             "user_id": user_id,
             "home_server": self.hs.hostname,
         }
-        if not params.get("inhibit_login", False):
+        # We don't want to log the user in if we're going to deny them access because
+        # they need to be approved first.
+        if not params.get("inhibit_login", False) and not self._require_approval:
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
             (