diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 17079ff781..a2f328cafe 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -31,6 +31,7 @@ from typing import (
Optional,
Tuple,
Union,
+ cast,
)
# `Literal` appears with Python 3.8.
@@ -168,7 +169,10 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
]
],
]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[
+ Callable[[UserProfile], Awaitable[bool]],
+ Callable[[UserProfile, str], Awaitable[bool]],
+]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
@@ -716,7 +720,9 @@ class SpamCheckerModuleApiCallbacks:
return self.NOT_SPAM
- async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
+ async def check_username_for_spam(
+ self, user_profile: UserProfile, requester_id: str
+ ) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -727,15 +733,33 @@ class SpamCheckerModuleApiCallbacks:
* user_id
* display_name
* avatar_url
+ requester_id: The user ID of the user making the user directory search request.
Returns:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
+ checker_args = inspect.signature(callback)
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
- res = await delay_cancellation(callback(user_profile.copy()))
+ # Also ensure backwards compatibility with spam checker callbacks
+ # that don't expect the requester_id argument.
+ if len(checker_args.parameters) == 2:
+ callback_with_requester_id = cast(
+ Callable[[UserProfile, str], Awaitable[bool]], callback
+ )
+ res = await delay_cancellation(
+ callback_with_requester_id(user_profile.copy(), requester_id)
+ )
+ else:
+ callback_without_requester_id = cast(
+ Callable[[UserProfile], Awaitable[bool]], callback
+ )
+ res = await delay_cancellation(
+ callback_without_requester_id(user_profile.copy())
+ )
+
if res:
return True
|