summary refs log tree commit diff
path: root/synapse/events/spamcheck.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/spamcheck.py')
-rw-r--r--synapse/events/spamcheck.py306
1 files changed, 218 insertions, 88 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index d5fa195094..45ec96dfc1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,7 +15,18 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
@@ -29,20 +40,186 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
+    ["synapse.events.EventBase"],
+    Awaitable[Union[bool, str]],
+]
+USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
+USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
+    [
+        Optional[dict],
+        Optional[str],
+        Collection[Tuple[str, str]],
+    ],
+    Awaitable[RegistrationBehaviour],
+]
+CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
+    [
+        Optional[dict],
+        Optional[str],
+        Collection[Tuple[str, str]],
+        Optional[str],
+    ],
+    Awaitable[RegistrationBehaviour],
+]
+CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
+    [ReadableFileWrapper, FileInfo],
+    Awaitable[bool],
+]
+
+
+def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
+    """Wrapper that loads spam checkers configured using the old configuration, and
+    registers the spam checker hooks they implement.
+    """
+    spam_checkers = []  # type: List[Any]
+    api = hs.get_module_api()
+    for module, config in hs.config.spam_checkers:
+        # Older spam checkers don't accept the `api` argument, so we
+        # try and detect support.
+        spam_args = inspect.getfullargspec(module)
+        if "api" in spam_args.args:
+            spam_checkers.append(module(config=config, api=api))
+        else:
+            spam_checkers.append(module(config=config))
+
+    # The known spam checker hooks. If a spam checker module implements a method
+    # which name appears in this set, we'll want to register it.
+    spam_checker_methods = {
+        "check_event_for_spam",
+        "user_may_invite",
+        "user_may_create_room",
+        "user_may_create_room_alias",
+        "user_may_publish_room",
+        "check_username_for_spam",
+        "check_registration_for_spam",
+        "check_media_file_for_spam",
+    }
+
+    for spam_checker in spam_checkers:
+        # Methods on legacy spam checkers might not be async, so we wrap them around a
+        # wrapper that will call maybe_awaitable on the result.
+        def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+            # f might be None if the callback isn't implemented by the module. In this
+            # case we don't want to register a callback at all so we return None.
+            if f is None:
+                return None
+
+            if f.__name__ == "check_registration_for_spam":
+                checker_args = inspect.signature(f)
+                if len(checker_args.parameters) == 3:
+                    # Backwards compatibility; some modules might implement a hook that
+                    # doesn't expect a 4th argument. In this case, wrap it in a function
+                    # that gives it only 3 arguments and drops the auth_provider_id on
+                    # the floor.
+                    def wrapper(
+                        email_threepid: Optional[dict],
+                        username: Optional[str],
+                        request_info: Collection[Tuple[str, str]],
+                        auth_provider_id: Optional[str],
+                    ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
+                        # We've already made sure f is not None above, but mypy doesn't
+                        # do well across function boundaries so we need to tell it f is
+                        # definitely not None.
+                        assert f is not None
+
+                        return f(
+                            email_threepid,
+                            username,
+                            request_info,
+                        )
+
+                    f = wrapper
+                elif len(checker_args.parameters) != 4:
+                    raise RuntimeError(
+                        "Bad signature for callback check_registration_for_spam",
+                    )
+
+            def run(*args, **kwargs):
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                return maybe_awaitable(f(*args, **kwargs))
+
+            return run
+
+        # Register the hooks through the module API.
+        hooks = {
+            hook: async_wrapper(getattr(spam_checker, hook, None))
+            for hook in spam_checker_methods
+        }
+
+        api.register_spam_checker_callbacks(**hooks)
+
 
 class SpamChecker:
-    def __init__(self, hs: "synapse.server.HomeServer"):
-        self.spam_checkers = []  # type: List[Any]
-        api = hs.get_module_api()
-
-        for module, config in hs.config.spam_checkers:
-            # Older spam checkers don't accept the `api` argument, so we
-            # try and detect support.
-            spam_args = inspect.getfullargspec(module)
-            if "api" in spam_args.args:
-                self.spam_checkers.append(module(config=config, api=api))
-            else:
-                self.spam_checkers.append(module(config=config))
+    def __init__(self):
+        self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+        self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
+        self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
+        self._user_may_create_room_alias_callbacks: List[
+            USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+        ] = []
+        self._user_may_publish_room_callbacks: List[USER_MAY_PUBLISH_ROOM_CALLBACK] = []
+        self._check_username_for_spam_callbacks: List[
+            CHECK_USERNAME_FOR_SPAM_CALLBACK
+        ] = []
+        self._check_registration_for_spam_callbacks: List[
+            CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+        ] = []
+        self._check_media_file_for_spam_callbacks: List[
+            CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
+        ] = []
+
+    def register_callbacks(
+        self,
+        check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+        user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+        user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+        user_may_create_room_alias: Optional[
+            USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+        ] = None,
+        user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
+        check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
+        check_registration_for_spam: Optional[
+            CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+        ] = None,
+        check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+    ):
+        """Register callbacks from module for each hook."""
+        if check_event_for_spam is not None:
+            self._check_event_for_spam_callbacks.append(check_event_for_spam)
+
+        if user_may_invite is not None:
+            self._user_may_invite_callbacks.append(user_may_invite)
+
+        if user_may_create_room is not None:
+            self._user_may_create_room_callbacks.append(user_may_create_room)
+
+        if user_may_create_room_alias is not None:
+            self._user_may_create_room_alias_callbacks.append(
+                user_may_create_room_alias,
+            )
+
+        if user_may_publish_room is not None:
+            self._user_may_publish_room_callbacks.append(user_may_publish_room)
+
+        if check_username_for_spam is not None:
+            self._check_username_for_spam_callbacks.append(check_username_for_spam)
+
+        if check_registration_for_spam is not None:
+            self._check_registration_for_spam_callbacks.append(
+                check_registration_for_spam,
+            )
+
+        if check_media_file_for_spam is not None:
+            self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
 
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
@@ -60,9 +237,10 @@ class SpamChecker:
             True or a string if the event is spammy. If a string is returned it
             will be used as the error message returned to the user.
         """
-        for spam_checker in self.spam_checkers:
-            if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
-                return True
+        for callback in self._check_event_for_spam_callbacks:
+            res = await callback(event)  # type: Union[bool, str]
+            if res:
+                return res
 
         return False
 
@@ -81,15 +259,8 @@ class SpamChecker:
         Returns:
             True if the user may send an invite, otherwise False
         """
-        for spam_checker in self.spam_checkers:
-            if (
-                await maybe_awaitable(
-                    spam_checker.user_may_invite(
-                        inviter_userid, invitee_userid, room_id
-                    )
-                )
-                is False
-            ):
+        for callback in self._user_may_invite_callbacks:
+            if await callback(inviter_userid, invitee_userid, room_id) is False:
                 return False
 
         return True
@@ -105,11 +276,8 @@ class SpamChecker:
         Returns:
             True if the user may create a room, otherwise False
         """
-        for spam_checker in self.spam_checkers:
-            if (
-                await maybe_awaitable(spam_checker.user_may_create_room(userid))
-                is False
-            ):
+        for callback in self._user_may_create_room_callbacks:
+            if await callback(userid) is False:
                 return False
 
         return True
@@ -128,13 +296,8 @@ class SpamChecker:
         Returns:
             True if the user may create a room alias, otherwise False
         """
-        for spam_checker in self.spam_checkers:
-            if (
-                await maybe_awaitable(
-                    spam_checker.user_may_create_room_alias(userid, room_alias)
-                )
-                is False
-            ):
+        for callback in self._user_may_create_room_alias_callbacks:
+            if await callback(userid, room_alias) is False:
                 return False
 
         return True
@@ -151,13 +314,8 @@ class SpamChecker:
         Returns:
             True if the user may publish the room, otherwise False
         """
-        for spam_checker in self.spam_checkers:
-            if (
-                await maybe_awaitable(
-                    spam_checker.user_may_publish_room(userid, room_id)
-                )
-                is False
-            ):
+        for callback in self._user_may_publish_room_callbacks:
+            if await callback(userid, room_id) is False:
                 return False
 
         return True
@@ -177,15 +335,11 @@ class SpamChecker:
         Returns:
             True if the user is spammy.
         """
-        for spam_checker in self.spam_checkers:
-            # For backwards compatibility, only run if the method exists on the
-            # spam checker
-            checker = getattr(spam_checker, "check_username_for_spam", None)
-            if checker:
-                # Make a copy of the user profile object to ensure the spam checker
-                # cannot modify it.
-                if await maybe_awaitable(checker(user_profile.copy())):
-                    return True
+        for callback in self._check_username_for_spam_callbacks:
+            # Make a copy of the user profile object to ensure the spam checker cannot
+            # modify it.
+            if await callback(user_profile.copy()):
+                return True
 
         return False
 
@@ -211,33 +365,13 @@ class SpamChecker:
             Enum for how the request should be handled
         """
 
-        for spam_checker in self.spam_checkers:
-            # For backwards compatibility, only run if the method exists on the
-            # spam checker
-            checker = getattr(spam_checker, "check_registration_for_spam", None)
-            if checker:
-                # Provide auth_provider_id if the function supports it
-                checker_args = inspect.signature(checker)
-                if len(checker_args.parameters) == 4:
-                    d = checker(
-                        email_threepid,
-                        username,
-                        request_info,
-                        auth_provider_id,
-                    )
-                elif len(checker_args.parameters) == 3:
-                    d = checker(email_threepid, username, request_info)
-                else:
-                    logger.error(
-                        "Invalid signature for %s.check_registration_for_spam. Denying registration",
-                        spam_checker.__module__,
-                    )
-                    return RegistrationBehaviour.DENY
-
-                behaviour = await maybe_awaitable(d)
-                assert isinstance(behaviour, RegistrationBehaviour)
-                if behaviour != RegistrationBehaviour.ALLOW:
-                    return behaviour
+        for callback in self._check_registration_for_spam_callbacks:
+            behaviour = await (
+                callback(email_threepid, username, request_info, auth_provider_id)
+            )
+            assert isinstance(behaviour, RegistrationBehaviour)
+            if behaviour != RegistrationBehaviour.ALLOW:
+                return behaviour
 
         return RegistrationBehaviour.ALLOW
 
@@ -275,13 +409,9 @@ class SpamChecker:
             allowed.
         """
 
-        for spam_checker in self.spam_checkers:
-            # For backwards compatibility, only run if the method exists on the
-            # spam checker
-            checker = getattr(spam_checker, "check_media_file_for_spam", None)
-            if checker:
-                spam = await maybe_awaitable(checker(file_wrapper, file_info))
-                if spam:
-                    return True
+        for callback in self._check_media_file_for_spam_callbacks:
+            spam = await callback(file_wrapper, file_info)
+            if spam:
+                return True
 
         return False