diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 4456d1b81e..7cee442145 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
+CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
+ [
+ str,
+ Optional[str],
+ Optional[str],
+ Collection[Tuple[Optional[str], str]],
+ Optional[str],
+ ],
+ Awaitable[
+ Union[
+ Literal["NOT_SPAM"],
+ Codes,
+ # Highly experimental, not officially part of the spamchecker API, may
+ # disappear without warning depending on the results of ongoing
+ # experiments.
+ # Use this to return additional information as part of an error.
+ Tuple[Codes, JsonDict],
+ ]
+ ],
+]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
+ self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
def register_callbacks(
self,
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+ check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
+ if check_login_for_spam is not None:
+ self._check_login_for_spam_callbacks.append(check_login_for_spam)
+
@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
+
+ async def check_login_for_spam(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ request_info: Collection[Tuple[Optional[str], str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
+ """Checks if we should allow the given registration request.
+
+ Args:
+ user_id: The request user ID
+ request_info: List of tuples of user agent and IP that
+ were used during the registration process.
+ auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
+ "cas". If any. Note this does not include users registered
+ via a password provider.
+
+ Returns:
+ Enum for how the request should be handled
+ """
+
+ for callback in self._check_login_for_spam_callbacks:
+ with Measure(
+ self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+ ):
+ res = await delay_cancellation(
+ callback(
+ user_id,
+ device_id,
+ initial_display_name,
+ request_info,
+ auth_provider_id,
+ )
+ )
+ # Normalize return values to `Codes` or `"NOT_SPAM"`.
+ if res is self.NOT_SPAM:
+ continue
+ elif isinstance(res, synapse.api.errors.Codes):
+ return res, {}
+ elif (
+ isinstance(res, tuple)
+ and len(res) == 2
+ and isinstance(res[0], synapse.api.errors.Codes)
+ and isinstance(res[1], dict)
+ ):
+ return res
+ else:
+ logger.warning(
+ "Module returned invalid value, rejecting login as spam"
+ )
+ return synapse.api.errors.Codes.FORBIDDEN, {}
+
+ return self.NOT_SPAM
|