summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_register.py120
-rw-r--r--tests/handlers/test_user_directory.py21
2 files changed, 93 insertions, 48 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c51763f41a..a9fd3036dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -27,6 +27,58 @@ from tests.utils import mock_getRawHeaders
 from .. import unittest
 
 
+class TestSpamChecker:
+    def __init__(self, config, api):
+        api.register_spam_checker_callbacks(
+            check_registration_for_spam=self.check_registration_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config):
+        return config
+
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        pass
+
+
+class DenyAll(TestSpamChecker):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+    async def check_registration_for_spam(
+        self,
+        email_threepid,
+        username,
+        request_info,
+        auth_provider_id,
+    ):
+        return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+    async def check_registration_for_spam(
+        self, email_threepid, username, request_info, auth_provider_id=None
+    ):
+        # Reject any user coming from CAS and whose username contains profanity
+        if auth_provider_id == "cas" and "flimflob" in username:
+            return RegistrationBehaviour.DENY
+        return RegistrationBehaviour.ALLOW
+
+
 class RegistrationTestCase(unittest.HomeserverTestCase):
     """Tests the RegistrationHandler."""
 
@@ -42,6 +94,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         hs_config["limit_usage_by_mau"] = True
 
         hs = self.setup_test_homeserver(config=hs_config)
+
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+
         return hs
 
     def prepare(self, reactor, clock, hs):
@@ -465,34 +522,30 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".DenyAll",
+                }
+            ]
+        }
+    )
     def test_spam_checker_deny(self):
         """A spam checker can deny registration, which results in an error."""
-
-        class DenyAll:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info
-            ):
-                return RegistrationBehaviour.DENY
-
-        # Configure a spam checker that denies all users.
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [DenyAll()]
-
         self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".BanAll",
+                }
+            ]
+        }
+    )
     def test_spam_checker_shadow_ban(self):
         """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
-        class BanAll:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info
-            ):
-                return RegistrationBehaviour.SHADOW_BAN
-
-        # Configure a spam checker that denies all users.
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [BanAll()]
-
         user_id = self.get_success(self.handler.register_user(localpart="user"))
 
         # Get an access token.
@@ -512,22 +565,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(requester.shadow_banned)
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+                }
+            ]
+        }
+    )
     def test_spam_checker_receives_sso_type(self):
         """Test rejecting registration based on SSO type"""
-
-        class BanBadIdPUser:
-            def check_registration_for_spam(
-                self, email_threepid, username, request_info, auth_provider_id=None
-            ):
-                # Reject any user coming from CAS and whose username contains profanity
-                if auth_provider_id == "cas" and "flimflob" in username:
-                    return RegistrationBehaviour.DENY
-                return RegistrationBehaviour.ALLOW
-
-        # Configure a spam checker that denies a certain user on a specific IdP
-        spam_checker = self.hs.get_spam_checker()
-        spam_checker.spam_checkers = [BanBadIdPUser()]
-
         f = self.get_failure(
             self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
             SynapseError,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index daac37abd8..549876dc85 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -312,15 +312,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
 
+        async def allow_all(user_profile):
+            # Allow all users.
+            return False
+
         # Configure a spam checker that does not filter any users.
         spam_checker = self.hs.get_spam_checker()
-
-        class AllowAll:
-            async def check_username_for_spam(self, user_profile):
-                # Allow all users.
-                return False
-
-        spam_checker.spam_checkers = [AllowAll()]
+        spam_checker._check_username_for_spam_callbacks = [allow_all]
 
         # The results do not change:
         # We get one search result when searching for user2 by user1.
@@ -328,12 +326,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        class BlockAll:
-            async def check_username_for_spam(self, user_profile):
-                # All users are spammy.
-                return True
+        async def block_all(user_profile):
+            # All users are spammy.
+            return True
 
-        spam_checker.spam_checkers = [BlockAll()]
+        spam_checker._check_username_for_spam_callbacks = [block_all]
 
         # User1 now gets no search results for any of the other users.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))