summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/site.py11
-rw-r--r--synapse/module_api/__init__.py3
-rw-r--r--synapse/module_api/callbacks/spamchecker_callbacks.py80
-rw-r--r--synapse/rest/client/login.py52
4 files changed, 142 insertions, 4 deletions
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c530966ef3..5b5a7c1e59 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -521,6 +521,11 @@ class SynapseRequest(Request):
         else:
             return self.getClientAddress().host
 
+    def request_info(self) -> "RequestInfo":
+        h = self.getHeader(b"User-Agent")
+        user_agent = h.decode("ascii", "replace") if h else None
+        return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
+
 
 class XForwardedForRequest(SynapseRequest):
     """Request object which honours proxy headers
@@ -661,3 +666,9 @@ class SynapseSite(Site):
 
     def log(self, request: SynapseRequest) -> None:
         pass
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RequestInfo:
+    user_agent: Optional[str]
+    ip: str
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 84b2aef620..95f7800111 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
 )
 from synapse.module_api.callbacks.spamchecker_callbacks import (
     CHECK_EVENT_FOR_SPAM_CALLBACK,
+    CHECK_LOGIN_FOR_SPAM_CALLBACK,
     CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
     CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
     CHECK_USERNAME_FOR_SPAM_CALLBACK,
@@ -302,6 +303,7 @@ class ModuleApi:
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
         ] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+        check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
     ) -> None:
         """Registers callbacks for spam checking capabilities.
 
@@ -319,6 +321,7 @@ class ModuleApi:
             check_username_for_spam=check_username_for_spam,
             check_registration_for_spam=check_registration_for_spam,
             check_media_file_for_spam=check_media_file_for_spam,
+            check_login_for_spam=check_login_for_spam,
         )
 
     def register_account_validity_callbacks(
diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 4456d1b81e..7cee442145 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
         ]
     ],
 ]
+CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
+    [
+        str,
+        Optional[str],
+        Optional[str],
+        Collection[Tuple[Optional[str], str]],
+        Optional[str],
+    ],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+        ]
+    ],
+]
 
 
 def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
         self._check_media_file_for_spam_callbacks: List[
             CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
         ] = []
+        self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
 
     def register_callbacks(
         self,
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
         ] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+        check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
     ) -> None:
         """Register callbacks from module for each hook."""
         if check_event_for_spam is not None:
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
         if check_media_file_for_spam is not None:
             self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
 
+        if check_login_for_spam is not None:
+            self._check_login_for_spam_callbacks.append(check_login_for_spam)
+
     @trace
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
                     return synapse.api.errors.Codes.FORBIDDEN, {}
 
         return self.NOT_SPAM
+
+    async def check_login_for_spam(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        request_info: Collection[Tuple[Optional[str], str]],
+        auth_provider_id: Optional[str] = None,
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
+        """Checks if we should allow the given registration request.
+
+        Args:
+            user_id: The request user ID
+            request_info: List of tuples of user agent and IP that
+                were used during the registration process.
+            auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+                "cas". If any. Note this does not include users registered
+                via a password provider.
+
+        Returns:
+            Enum for how the request should be handled
+        """
+
+        for callback in self._check_login_for_spam_callbacks:
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                res = await delay_cancellation(
+                    callback(
+                        user_id,
+                        device_id,
+                        initial_display_name,
+                        request_info,
+                        auth_provider_id,
+                    )
+                )
+                # Normalize return values to `Codes` or `"NOT_SPAM"`.
+                if res is self.NOT_SPAM:
+                    continue
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting login as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+
+        return self.NOT_SPAM
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 6493b00bb8..d724c68920 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -50,7 +50,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.http.site import SynapseRequest
+from synapse.http.site import RequestInfo, SynapseRequest
 from synapse.rest.client._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
@@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self._sso_handler = hs.get_sso_handler()
+        self._spam_checker = hs.get_module_api_callbacks().spam_checker
 
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
@@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
             self._refresh_tokens_enabled and client_requested_refresh_token
         )
 
+        request_info = request.request_info()
+
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 requester = await self.auth.get_user_by_req(request)
@@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
                     login_submission,
                     appservice,
                     should_issue_refresh_token=should_issue_refresh_token,
+                    request_info=request_info,
                 )
             elif (
                 self.jwt_enabled
@@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
                 result = await self._do_jwt_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
+                    request_info=request_info,
                 )
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 await self._address_ratelimiter.ratelimit(
@@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
                 result = await self._do_token_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
+                    request_info=request_info,
                 )
             else:
                 await self._address_ratelimiter.ratelimit(
@@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
                 result = await self._do_other_login(
                     login_submission,
                     should_issue_refresh_token=should_issue_refresh_token,
+                    request_info=request_info,
                 )
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
         login_submission: JsonDict,
         appservice: ApplicationService,
         should_issue_refresh_token: bool = False,
+        *,
+        request_info: RequestInfo,
     ) -> LoginResponse:
         identifier = login_submission.get("identifier")
         logger.info("Got appservice login request with identifier: %r", identifier)
@@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
             # The user represented by an appservice's configured sender_localpart
             # is not actually created in Synapse.
             should_check_deactivated=qualified_user_id != appservice.sender,
+            request_info=request_info,
         )
 
     async def _do_other_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+        self,
+        login_submission: JsonDict,
+        should_issue_refresh_token: bool = False,
+        *,
+        request_info: RequestInfo,
     ) -> LoginResponse:
         """Handle non-token/saml/jwt logins
 
@@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
             login_submission,
             callback,
             should_issue_refresh_token=should_issue_refresh_token,
+            request_info=request_info,
         )
         return result
 
@@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
         should_issue_refresh_token: bool = False,
         auth_provider_session_id: Optional[str] = None,
         should_check_deactivated: bool = True,
+        *,
+        request_info: RequestInfo,
     ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
 
                 This exists purely for appservice's configured sender_localpart
                 which doesn't have an associated user in the database.
+            request_info: The user agent/IP address of the user.
 
         Returns:
             Dictionary of account information after successful login.
@@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
                 )
 
         initial_display_name = login_submission.get("initial_device_display_name")
+        spam_check = await self._spam_checker.check_login_for_spam(
+            user_id,
+            device_id=device_id,
+            initial_display_name=initial_display_name,
+            request_info=[(request_info.user_agent, request_info.ip)],
+            auth_provider_id=auth_provider_id,
+        )
+        if spam_check != self._spam_checker.NOT_SPAM:
+            logger.info("Blocking login due to spam checker")
+            raise SynapseError(
+                403,
+                msg="Login was blocked by the server",
+                errcode=spam_check[0],
+                additional_fields=spam_check[1],
+            )
+
         (
             device_id,
             access_token,
@@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
         return result
 
     async def _do_token_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+        self,
+        login_submission: JsonDict,
+        should_issue_refresh_token: bool = False,
+        *,
+        request_info: RequestInfo,
     ) -> LoginResponse:
         """
         Handle token login.
@@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
             auth_provider_id=res.auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
             auth_provider_session_id=res.auth_provider_session_id,
+            request_info=request_info,
         )
 
     async def _do_jwt_login(
-        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+        self,
+        login_submission: JsonDict,
+        should_issue_refresh_token: bool = False,
+        *,
+        request_info: RequestInfo,
     ) -> LoginResponse:
         """
         Handle the custom JWT login.
@@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
             login_submission,
             create_non_existent_users=True,
             should_issue_refresh_token=should_issue_refresh_token,
+            request_info=request_info,
         )