summary refs log tree commit diff
path: root/tests/handlers/test_register.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_register.py')
-rw-r--r--tests/handlers/test_register.py120
1 files changed, 84 insertions, 36 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c51763f41a..a9fd3036dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -27,6 +27,58 @@ from tests.utils import mock_getRawHeaders
 from .. import unittest
 
 
+class TestSpamChecker:
+    def __init__(self, config, api):
+        api.register_spam_checker_callbacks(
+            check_registration_for_spam=self.check_registration_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config):
+        return config
+
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        pass
+
+
+class DenyAll(TestSpamChecker):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+    async def check_registration_for_spam(
+        self, email_threepid, username, request_info, auth_provider_id=None
+    ):
+        # Reject any user coming from CAS and whose username contains profanity
+        if auth_provider_id == "cas" and "flimflob" in username:
+            return RegistrationBehaviour.DENY
+        return RegistrationBehaviour.ALLOW
+
+
 class RegistrationTestCase(unittest.HomeserverTestCase):
     """Tests the RegistrationHandler."""
 
@@ -42,6 +94,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         hs_config["limit_usage_by_mau"] = True
 
         hs = self.setup_test_homeserver(config=hs_config)
+
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+
         return hs
 
     def prepare(self, reactor, clock, hs):
@@ -465,34 +522,30 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".DenyAll",
+                }
+            ]
+        }
+    )
     def test_spam_checker_deny(self):
         """A spam checker can deny registration, which results in an error."""
-
-        class DenyAll:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info
-            ):
-                return RegistrationBehaviour.DENY
-
-        # Configure a spam checker that denies all users.
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [DenyAll()]
-
         self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".BanAll",
+                }
+            ]
+        }
+    )
     def test_spam_checker_shadow_ban(self):
         """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
-        class BanAll:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info
-            ):
-                return RegistrationBehaviour.SHADOW_BAN
-
-        # Configure a spam checker that denies all users.
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [BanAll()]
-
         user_id = self.get_success(self.handler.register_user(localpart="user"))
 
         # Get an access token.
@@ -512,22 +565,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(requester.shadow_banned)
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+                }
+            ]
+        }
+    )
     def test_spam_checker_receives_sso_type(self):
         """Test rejecting registration based on SSO type"""
-
-        class BanBadIdPUser:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info, auth_provider_id=None
-            ):
-                # Reject any user coming from CAS and whose username contains profanity
-                if auth_provider_id == "cas" and "flimflob" in username:
-                    return RegistrationBehaviour.DENY
-                return RegistrationBehaviour.ALLOW
-
-        # Configure a spam checker that denies a certain user on a specific IdP
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [BanBadIdPUser()]
-
         f = self.get_failure(
             self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
             SynapseError,