diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c51763f41a..a9fd3036dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -27,6 +27,58 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
+class TestSpamChecker:
+ def __init__(self, config, api):
+ api.register_spam_checker_callbacks(
+ check_registration_for_spam=self.check_registration_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ pass
+
+
+class DenyAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+ async def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+
class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler."""
@@ -42,6 +94,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config)
+
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
return hs
def prepare(self, reactor, clock, hs):
@@ -465,34 +522,30 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".DenyAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_deny(self):
"""A spam checker can deny registration, which results in an error."""
-
- class DenyAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.DENY
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [DenyAll()]
-
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_shadow_ban(self):
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
- class BanAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.SHADOW_BAN
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanAll()]
-
user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token.
@@ -512,22 +565,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+ }
+ ]
+ }
+ )
def test_spam_checker_receives_sso_type(self):
"""Test rejecting registration based on SSO type"""
-
- class BanBadIdPUser:
- def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
- # Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
- return RegistrationBehaviour.DENY
- return RegistrationBehaviour.ALLOW
-
- # Configure a spam checker that denies a certain user on a specific IdP
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanBadIdPUser()]
-
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
SynapseError,
|