diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
|