diff --git a/changelog.d/15838.feature b/changelog.d/15838.feature
new file mode 100644
index 0000000000..04c77bd723
--- /dev/null
+++ b/changelog.d/15838.feature
@@ -0,0 +1 @@
+Add spam checker module API for logins.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 1a0c6ec954..ffdfe6082e 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
+
+### `check_login_for_spam`
+
+_First introduced in Synapse v1.87.0_
+
+```python
+async def check_login_for_spam(
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
+```
+
+Called when a user logs in.
+
+The arguments passed to this callback are:
+
+* `user_id`: The user ID the user is logging in with
+* `device_id`: The device ID the user is re-logging into.
+* `initial_display_name`: The device display name, if any.
+* `request_info`: A collection of tuples, which first item is a user agent, and which
+ second item is an IP address. These user agents and IP addresses are the ones that were
+ used during the login process.
+* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
+The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
+be used. If this happens, Synapse will not call any of the subsequent implementations of
+this callback.
+
+*Note:* This will not be called when a user registers.
+
+
## Example
The example below is a module that implements the spam checker callback
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c530966ef3..5b5a7c1e59 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -521,6 +521,11 @@ class SynapseRequest(Request):
else:
return self.getClientAddress().host
+ def request_info(self) -> "RequestInfo":
+ h = self.getHeader(b"User-Agent")
+ user_agent = h.decode("ascii", "replace") if h else None
+ return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
+
class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
@@ -661,3 +666,9 @@ class SynapseSite(Site):
def log(self, request: SynapseRequest) -> None:
pass
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RequestInfo:
+ user_agent: Optional[str]
+ ip: str
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 84b2aef620..95f7800111 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
+ CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
@@ -302,6 +303,7 @@ class ModuleApi:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.
@@ -319,6 +321,7 @@ class ModuleApi:
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
+ check_login_for_spam=check_login_for_spam,
)
def register_account_validity_callbacks(
diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 4456d1b81e..7cee442145 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
+CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
+ [
+ str,
+ Optional[str],
+ Optional[str],
+ Collection[Tuple[Optional[str], str]],
+ Optional[str],
+ ],
+ Awaitable[
+ Union[
+ Literal["NOT_SPAM"],
+ Codes,
+ # Highly experimental, not officially part of the spamchecker API, may
+ # disappear without warning depending on the results of ongoing
+ # experiments.
+ # Use this to return additional information as part of an error.
+ Tuple[Codes, JsonDict],
+ ]
+ ],
+]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
+ self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
def register_callbacks(
self,
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
+ if check_login_for_spam is not None:
+ self._check_login_for_spam_callbacks.append(check_login_for_spam)
+
@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
+ """Checks if we should allow the given registration request.
+
+ Args:
+ user_id: The request user ID
+ request_info: List of tuples of user agent and IP that
+ were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+ "cas". If any. Note this does not include users registered
+ via a password provider.
+
+ Returns:
+ Enum for how the request should be handled
+ """
+
+ for callback in self._check_login_for_spam_callbacks:
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ res = await delay_cancellation(
+ callback(
+ user_id,
+ device_id,
+ initial_display_name,
+ request_info,
+ auth_provider_id,
+ )
+ )
+ # Normalize return values to `Codes` or `"NOT_SPAM"`.
+ if res is self.NOT_SPAM:
+ continue
+ elif isinstance(res, synapse.api.errors.Codes):
+ return res, {}
+ elif (
+ isinstance(res, tuple)
+ and len(res) == 2
+ and isinstance(res[0], synapse.api.errors.Codes)
+ and isinstance(res[1], dict)
+ ):
+ return res
+ else:
+ logger.warning(
+ "Module returned invalid value, rejecting login as spam"
+ )
+ return synapse.api.errors.Codes.FORBIDDEN, {}
+
+ return self.NOT_SPAM
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 6493b00bb8..d724c68920 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -50,7 +50,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.http.site import SynapseRequest
+from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
@@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
+ self._spam_checker = hs.get_module_api_callbacks().spam_checker
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
@@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
self._refresh_tokens_enabled and client_requested_refresh_token
)
+ request_info = request.request_info()
+
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
@@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
elif (
self.jwt_enabled
@@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
@@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
@@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
+ request_info=request_info,
)
async def _do_other_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
@@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
return result
@@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
+ request_info: The user agent/IP address of the user.
Returns:
Dictionary of account information after successful login.
@@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
)
initial_display_name = login_submission.get("initial_device_display_name")
+ spam_check = await self._spam_checker.check_login_for_spam(
+ user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ request_info=[(request_info.user_agent, request_info.ip)],
+ auth_provider_id=auth_provider_id,
+ )
+ if spam_check != self._spam_checker.NOT_SPAM:
+ logger.info("Blocking login due to spam checker")
+ raise SynapseError(
+ 403,
+ msg="Login was blocked by the server",
+ errcode=spam_check[0],
+ additional_fields=spam_check[1],
+ )
+
(
device_id,
access_token,
@@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
@@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
+ request_info=request_info,
)
async def _do_jwt_login(
- self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ self,
+ login_submission: JsonDict,
+ should_issue_refresh_token: bool = False,
+ *,
+ request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
@@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
+ request_info=request_info,
)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f3c3bc69a9..ffbc13bb8d 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
+from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]
+class TestSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ return "NOT_SPAM"
+
+
+class DenyAllSpamChecker:
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_spam_checker_callbacks(
+ check_login_for_spam=self.check_login_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ return None
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[
+ Literal["NOT_SPAM"],
+ Tuple["synapse.module_api.errors.Codes", JsonDict],
+ ]:
+ # Return an odd set of values to ensure that they get correctly passed
+ # to the client.
+ return Codes.LIMIT_EXCEEDED, {"extra": "value"}
+
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__
+ + "."
+ + TestSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_allow(self) -> None:
+ """Check that that adding a spam checker doesn't break login."""
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": DenyAllSpamChecker.__module__
+ + "."
+ + DenyAllSpamChecker.__qualname__
+ }
+ ]
+ }
+ )
+ def test_spam_checker_deny(self) -> None:
+ """Check that login"""
+
+ self.register_user("kermit", "monkey")
+
+ body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login",
+ body,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertDictContainsSubset(
+ {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
|