summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-09-29 14:23:24 +0100
committerGitHub <noreply@github.com>2022-09-29 15:23:24 +0200
commitbe76cd8200b18f3c68b895f85ac7ef5b0ddc2466 (patch)
treec32c1b1e7a835b9f970fc4ae28ded0dd858ff841 /tests
parentExplicit cast to enforce type hints. (#13939) (diff)
downloadsynapse-be76cd8200b18f3c68b895f85ac7ef5b0ddc2466.tar.xz
Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556)
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/admin/test_user.py186
-rw-r--r--tests/rest/client/test_auth.py33
-rw-r--r--tests/rest/client/test_login.py41
-rw-r--r--tests/rest/client/test_register.py32
-rw-r--r--tests/rest/client/utils.py12
-rw-r--r--tests/storage/test_registration.py102
6 files changed, 398 insertions, 8 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 1847e6ad6b..4c1ce33463 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
-from synapse.rest.client import devices, login, logout, profile, room, sync
+from synapse.rest.client import devices, login, logout, profile, register, room, sync
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID
@@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo", "user_id")
         _search_test(None, "bar", "user_id")
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
     def test_invalid_parameter(self) -> None:
         """
         If parameters are invalid, an error is returned.
@@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
+        # invalid approved
+        channel = self.make_request(
+            "GET",
+            self.url + "?approved=not_bool",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
         # unkown order_by
         channel = self.make_request(
             "GET",
@@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
         self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_filter_out_approved(self) -> None:
+        """Tests that the endpoint can filter out approved users."""
+        # Create our users.
+        self._create_users(2)
+
+        # Get the list of users.
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        # Exclude the admin, because we don't want to accidentally un-approve the admin.
+        non_admin_user_ids = [
+            user["name"]
+            for user in channel.json_body["users"]
+            if user["name"] != self.admin_user
+        ]
+
+        self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
+
+        # Select a user and un-approve them. We do this rather than the other way around
+        # because, since these users are created by an admin, we consider them already
+        # approved.
+        not_approved_user = non_admin_user_ids[0]
+
+        channel = self.make_request(
+            "PUT",
+            f"/_synapse/admin/v2/users/{not_approved_user}",
+            {"approved": False},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        # Now get the list of users again, this time filtering out approved users.
+        channel = self.make_request(
+            "GET",
+            self.url + "?approved=false",
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+
+        non_admin_user_ids = [
+            user["name"]
+            for user in channel.json_body["users"]
+            if user["name"] != self.admin_user
+        ]
+
+        # We should only have our unapproved user now.
+        self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
+        self.assertEqual(not_approved_user, non_admin_user_ids[0])
+
     def _order_test(
         self,
         expected_user_list: List[str],
@@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets,
         login.register_servlets,
         sync.register_servlets,
+        register.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_approve_account(self) -> None:
+        """Tests that approving an account correctly sets the approved flag for the user."""
+        url = self.url_prefix % "@bob:test"
+
+        # Create the user using the client-server API since otherwise the user will be
+        # marked as approved automatically.
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "bob",
+                "password": "test",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(False, channel.json_body["approved"])
+
+        # Approve user
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content={"approved": True},
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(True, channel.json_body["approved"])
+
+        # Check that the user is now approved
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIs(True, channel.json_body["approved"])
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_register_approved(self) -> None:
+        url = self.url_prefix % "@bob:test"
+
+        # Create user
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content={"password": "abc123", "approved": True},
+        )
+
+        self.assertEqual(201, channel.code, msg=channel.json_body)
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual(1, channel.json_body["approved"])
+
+        # Get user
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual(1, channel.json_body["approved"])
+
     def _is_erased(self, user_id: str, expect: bool) -> None:
         """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 05355c7fb6..090cef5216 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase):
             body={"auth": {"session": session_id}},
         )
 
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config(
+        {
+            "oidc_config": TEST_OIDC_CONFIG,
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            },
+        }
+    )
+    def test_sso_not_approved(self) -> None:
+        """Tests that if we register a user via SSO while requiring approval for new
+        accounts, we still raise the correct error before logging the user in.
+        """
+        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+
+        self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
+        )
+
+        # Check that we didn't register a device for the user during the login attempt.
+        devices = self.get_success(
+            self.hs.get_datastores().main.get_devices_by_user("@username:test")
+        )
+
+        self.assertEqual(len(devices), 0)
+
 
 class RefreshAuthTests(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index e2a4d98275..e801ba8c8b 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 import synapse.rest.admin
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import devices, login, logout, register
 from synapse.rest.client.account import WhoamiRestServlet
@@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         logout.register_servlets,
         devices.register_servlets,
         lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
+        register.register_servlets,
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400)
         self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_require_approval(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "kermit",
+                "password": "monkey",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
+        params = {
+            "type": LoginType.PASSWORD,
+            "identifier": {"type": "m.id.user", "user": "kermit"},
+            "password": "monkey",
+        }
+        channel = self.make_request("POST", LOGIN_URL, params)
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
 
 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
 class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index b781875d52..11cf3939d8 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -22,7 +22,11 @@ import pkg_resources
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+    APP_SERVICE_REGISTRATION_TYPE,
+    ApprovalNoticeMedium,
+    LoginType,
+)
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
@@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.json_body)
         self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
 
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": True,
+                }
+            }
+        }
+    )
+    def test_require_approval(self) -> None:
+        channel = self.make_request(
+            "POST",
+            "register",
+            {
+                "username": "kermit",
+                "password": "monkey",
+                "auth": {"type": LoginType.DUMMY},
+            },
+        )
+        self.assertEqual(403, channel.code, channel.result)
+        self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+        self.assertEqual(
+            ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+        )
+
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index dd26145bf8..c249a42bb6 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -543,8 +543,12 @@ class RestHelper:
 
         return channel.json_body
 
-    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
-        """Log in (as a new user) via OIDC
+    def login_via_oidc(
+        self,
+        remote_user_id: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Log in via OIDC
 
         Returns the result of the final token login.
 
@@ -578,7 +582,9 @@ class RestHelper:
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        assert channel.code == HTTPStatus.OK
+        assert (
+            channel.code == expected_status
+        ), f"unexpected status in response: {channel.code}"
         return channel.json_body
 
     def auth_via_oidc(
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 853a93afab..05ea802008 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
 class RegistrationStoreTestCase(HomeserverTestCase):
@@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
                 "user_type": None,
                 "deactivated": 0,
                 "shadow_banned": 0,
+                "approved": 1,
             },
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
@@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
             ThreepidValidationError,
         )
         self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
+
+
+class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+
+        # If there's already some config for this feature in the default config, it
+        # means we're overriding it with @override_config. In this case we don't want
+        # to do anything more with it.
+        msc3866_config = config.get("experimental_features", {}).get("msc3866")
+        if msc3866_config is not None:
+            return config
+
+        # Require approval for all new accounts.
+        config["experimental_features"] = {
+            "msc3866": {
+                "enabled": True,
+                "require_approval_for_new_accounts": True,
+            }
+        }
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.user_id = "@my-user:test"
+        self.pwhash = "{xx1}123456789"
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": False,
+                }
+            }
+        }
+    )
+    def test_approval_not_required(self) -> None:
+        """Tests that if we don't require approval for new accounts, newly created
+        accounts are automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertTrue(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approval_required(self) -> None:
+        """Tests that if we require approval for new accounts, newly created accounts
+        are not automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertFalse(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+    def test_override(self) -> None:
+        """Tests that if we require approval for new accounts, but we explicitly say the
+        new user should be considered approved, they're marked as approved.
+        """
+        self.get_success(
+            self.store.register_user(
+                self.user_id,
+                self.pwhash,
+                approved=True,
+            )
+        )
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        self.assertIsNotNone(user)
+        assert user is not None
+        self.assertEqual(user["approved"], 1)
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approve_user(self) -> None:
+        """Tests that approving the user updates their approval status."""
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+        self.get_success(
+            self.store.update_user_approval_status(
+                UserID.from_string(self.user_id), True
+            )
+        )
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)