diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f3c3bc69a9..ffbc13bb8d 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
+from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]
+class TestSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ return "NOT_SPAM"
+
+
+class DenyAllSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ # Return an odd set of values to ensure that they get correctly passed
+ # to the client.
+ return Codes.LIMIT_EXCEEDED, {"extra": "value"}
+
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__
+ + "."
+ + TestSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_allow(self) -> None:
+ """Check that that adding a spam checker doesn't break login."""
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": DenyAllSpamChecker.__module__
+ + "."
+ + DenyAllSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_deny(self) -> None:
+ """Check that login"""
+
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertDictContainsSubset(
+ {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
|