summary refs log tree commit diff
path: root/tests/storage/test_registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_registration.py')
-rw-r--r--tests/storage/test_registration.py102
1 files changed, 101 insertions, 1 deletions
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 853a93afab..05ea802008 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
 class RegistrationStoreTestCase(HomeserverTestCase):
@@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
                 "user_type": None,
                 "deactivated": 0,
                 "shadow_banned": 0,
+                "approved": 1,
             },
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
@@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
             ThreepidValidationError,
         )
         self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
+
+
+class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+
+        # If there's already some config for this feature in the default config, it
+        # means we're overriding it with @override_config. In this case we don't want
+        # to do anything more with it.
+        msc3866_config = config.get("experimental_features", {}).get("msc3866")
+        if msc3866_config is not None:
+            return config
+
+        # Require approval for all new accounts.
+        config["experimental_features"] = {
+            "msc3866": {
+                "enabled": True,
+                "require_approval_for_new_accounts": True,
+            }
+        }
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.user_id = "@my-user:test"
+        self.pwhash = "{xx1}123456789"
+
+    @override_config(
+        {
+            "experimental_features": {
+                "msc3866": {
+                    "enabled": True,
+                    "require_approval_for_new_accounts": False,
+                }
+            }
+        }
+    )
+    def test_approval_not_required(self) -> None:
+        """Tests that if we don't require approval for new accounts, newly created
+        accounts are automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertTrue(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approval_required(self) -> None:
+        """Tests that if we require approval for new accounts, newly created accounts
+        are not automatically marked as approved.
+        """
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        assert user is not None
+        self.assertFalse(user["approved"])
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+    def test_override(self) -> None:
+        """Tests that if we require approval for new accounts, but we explicitly say the
+        new user should be considered approved, they're marked as approved.
+        """
+        self.get_success(
+            self.store.register_user(
+                self.user_id,
+                self.pwhash,
+                approved=True,
+            )
+        )
+
+        user = self.get_success(self.store.get_user_by_id(self.user_id))
+        self.assertIsNotNone(user)
+        assert user is not None
+        self.assertEqual(user["approved"], 1)
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)
+
+    def test_approve_user(self) -> None:
+        """Tests that approving the user updates their approval status."""
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertFalse(approved)
+
+        self.get_success(
+            self.store.update_user_approval_status(
+                UserID.from_string(self.user_id), True
+            )
+        )
+
+        approved = self.get_success(self.store.is_user_approved(self.user_id))
+        self.assertTrue(approved)