summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_login.py108
1 files changed, 106 insertions, 2 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f3c3bc69a9..ffbc13bb8d 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import time
 import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
 import pymacaroons
+from typing_extensions import Literal
 
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
+from synapse.module_api import ModuleApi
 from synapse.rest.client import devices, login, logout, register
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
 ]
 
 
+class TestSpamChecker:
+    def __init__(self, config: None, api: ModuleApi):
+        api.register_spam_checker_callbacks(
+            check_login_for_spam=self.check_login_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config: JsonDict) -> None:
+        return None
+
+    async def check_login_for_spam(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        request_info: Collection[Tuple[Optional[str], str]],
+        auth_provider_id: Optional[str] = None,
+    ) -> Union[
+        Literal["NOT_SPAM"],
+        Tuple["synapse.module_api.errors.Codes", JsonDict],
+    ]:
+        return "NOT_SPAM"
+
+
+class DenyAllSpamChecker:
+    def __init__(self, config: None, api: ModuleApi):
+        api.register_spam_checker_callbacks(
+            check_login_for_spam=self.check_login_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config: JsonDict) -> None:
+        return None
+
+    async def check_login_for_spam(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        request_info: Collection[Tuple[Optional[str], str]],
+        auth_provider_id: Optional[str] = None,
+    ) -> Union[
+        Literal["NOT_SPAM"],
+        Tuple["synapse.module_api.errors.Codes", JsonDict],
+    ]:
+        # Return an odd set of values to ensure that they get correctly passed
+        # to the client.
+        return Codes.LIMIT_EXCEEDED, {"extra": "value"}
+
+
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__
+                    + "."
+                    + TestSpamChecker.__qualname__
+                }
+            ]
+        }
+    )
+    def test_spam_checker_allow(self) -> None:
+        """Check that that adding a spam checker doesn't break login."""
+        self.register_user("kermit", "monkey")
+
+        body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/login",
+            body,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": DenyAllSpamChecker.__module__
+                    + "."
+                    + DenyAllSpamChecker.__qualname__
+                }
+            ]
+        }
+    )
+    def test_spam_checker_deny(self) -> None:
+        """Check that login"""
+
+        self.register_user("kermit", "monkey")
+
+        body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/login",
+            body,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertDictContainsSubset(
+            {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
+        )
+
 
 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
 class MultiSSOTestCase(unittest.HomeserverTestCase):