diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c51763f41a..a9fd3036dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -27,6 +27,58 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
+class TestSpamChecker:
+ def __init__(self, config, api):
+ api.register_spam_checker_callbacks(
+ check_registration_for_spam=self.check_registration_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ pass
+
+
+class DenyAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+ async def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+
class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler."""
@@ -42,6 +94,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config)
+
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
return hs
def prepare(self, reactor, clock, hs):
@@ -465,34 +522,30 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".DenyAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_deny(self):
"""A spam checker can deny registration, which results in an error."""
-
- class DenyAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.DENY
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [DenyAll()]
-
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_shadow_ban(self):
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
- class BanAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.SHADOW_BAN
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanAll()]
-
user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token.
@@ -512,22 +565,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+ }
+ ]
+ }
+ )
def test_spam_checker_receives_sso_type(self):
"""Test rejecting registration based on SSO type"""
-
- class BanBadIdPUser:
- def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
- # Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
- return RegistrationBehaviour.DENY
- return RegistrationBehaviour.ALLOW
-
- # Configure a spam checker that denies a certain user on a specific IdP
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanBadIdPUser()]
-
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
SynapseError,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index daac37abd8..549876dc85 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -312,15 +312,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ async def allow_all(user_profile):
+ # Allow all users.
+ return False
+
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
-
- class AllowAll:
- async def check_username_for_spam(self, user_profile):
- # Allow all users.
- return False
-
- spam_checker.spam_checkers = [AllowAll()]
+ spam_checker._check_username_for_spam_callbacks = [allow_all]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -328,12 +326,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll:
- async def check_username_for_spam(self, user_profile):
- # All users are spammy.
- return True
+ async def block_all(user_profile):
+ # All users are spammy.
+ return True
- spam_checker.spam_checkers = [BlockAll()]
+ spam_checker._check_username_for_spam_callbacks = [block_all]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4a213d13dd..95e7075841 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -27,6 +27,7 @@ from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
+from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
from synapse.rest import admin
from synapse.rest.client.v1 import login
@@ -535,6 +536,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
+ load_legacy_spam_checkers(hs)
+
def default_config(self):
config = default_config("test")
|