summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2023-03-08 16:22:49 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2023-03-08 16:28:32 +0000
commit7441f1ba5fb7f638445cbf8d60a967702994246c (patch)
treea564fbe2360217ca11fa61a358dc8b8fb998e470
parentMove Spam Checker callbacks to a dedicated file (diff)
downloadsynapse-7441f1ba5fb7f638445cbf8d60a967702994246c.tar.xz
Update calling code for the Spam Checker
It's questionable whether `NOT_SPAM` should be defined in
"SpamCheckerModuleApiCallbacks", but putting just that in a separate
class feels a bit silly.
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/federation/federation_base.py6
-rw-r--r--synapse/federation/federation_server.py8
-rw-r--r--synapse/handlers/directory.py14
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/room_member.py22
-rw-r--r--synapse/handlers/user_directory.py6
-rw-r--r--synapse/rest/media/v1/media_storage.py7
-rw-r--r--tests/handlers/test_user_directory.py2
-rw-r--r--tests/rest/client/test_rooms.py26
-rw-r--r--tests/rest/media/v1/test_media_storage.py2
-rw-r--r--tests/server.py2
15 files changed, 78 insertions, 47 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 28062dd69d..7b4637e968 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -59,7 +59,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig, ManholeConfig
 from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
-from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseSite
@@ -68,6 +67,7 @@ from synapse.logging.opentracing import init_tracer
 from synapse.metrics import install_gc_manager, register_threadpool
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
 from synapse.types import ISynapseReactor
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 29fae716f5..3df975958d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -51,7 +51,7 @@ class FederationBase:
 
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.store = hs.get_datastores().main
         self._clock = hs.get_clock()
         self._storage_controllers = hs.get_storage_controllers()
@@ -137,9 +137,9 @@ class FederationBase:
                     )
             return redacted_event
 
-        spam_check = await self.spam_checker.check_event_for_spam(pdu)
+        spam_check = await self._spam_checker_module_callbacks.check_event_for_spam(pdu)
 
-        if spam_check != self.spam_checker.NOT_SPAM:
+        if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
             logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
             log_kv(
                 {
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6d99845de5..ea5b8b218b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -130,7 +130,7 @@ class FederationServer(FederationBase):
         super().__init__(hs)
 
         self.handler = hs.get_federation_handler()
-        self._spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -1129,7 +1129,7 @@ class FederationServer(FederationBase):
             logger.warning("event id %s: %s", pdu.event_id, e)
             raise FederationError("ERROR", 403, str(e), affected=pdu.event_id)
 
-        if await self._spam_checker.should_drop_federated_event(pdu):
+        if await self._spam_checker_module_callbacks.should_drop_federated_event(pdu):
             logger.warning(
                 "Unstaged federated event contains spam, dropping %s", pdu.event_id
             )
@@ -1174,7 +1174,9 @@ class FederationServer(FederationBase):
 
             origin, event = next
 
-            if await self._spam_checker.should_drop_federated_event(event):
+            if await self._spam_checker_module_callbacks.should_drop_federated_event(
+                event
+            ):
                 logger.warning(
                     "Staged federated event contains spam, dropping %s",
                     event.event_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1fb23cc9bf..5e8316e2e5 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -60,7 +60,7 @@ class DirectoryHandler:
             "directory", self.on_directory_query
         )
 
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
 
     async def _create_association(
         self,
@@ -145,10 +145,12 @@ class DirectoryHandler:
                         403, "You must be in the room to create an alias for it"
                     )
 
-            spam_check = await self.spam_checker.user_may_create_room_alias(
-                user_id, room_alias
+            spam_check = (
+                await self._spam_checker_module_callbacks.user_may_create_room_alias(
+                    user_id, room_alias
+                )
             )
-            if spam_check != self.spam_checker.NOT_SPAM:
+            if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
                 raise AuthError(
                     403,
                     "This user is not permitted to create this alias",
@@ -444,7 +446,9 @@ class DirectoryHandler:
         """
         user_id = requester.user.to_string()
 
-        spam_check = await self.spam_checker.user_may_publish_room(user_id, room_id)
+        spam_check = await self._spam_checker_module_callbacks.user_may_publish_room(
+            user_id, room_id
+        )
         if spam_check != NOT_SPAM:
             raise AuthError(
                 403,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5f2057269d..deb2997bf5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -141,7 +141,7 @@ class FederationHandler:
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.is_mine_id = hs.is_mine_id
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.event_creation_handler = hs.get_event_creation_handler()
         self.event_builder_factory = hs.get_event_builder_factory()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -1041,7 +1041,7 @@ class FederationHandler:
         if self.hs.config.server.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
-        spam_check = await self.spam_checker.user_may_invite(
+        spam_check = await self._spam_checker_module_callbacks.user_may_invite(
             event.sender, event.state_key, event.room_id
         )
         if spam_check != NOT_SPAM:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index aa90d0000d..98639f3ca3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -505,7 +505,7 @@ class EventCreationHandler:
 
         self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.third_party_event_rules: "ThirdPartyEventRules" = (
             self.hs.get_third_party_event_rules()
         )
@@ -1021,8 +1021,12 @@ class EventCreationHandler:
                     event.sender,
                 )
 
-                spam_check_result = await self.spam_checker.check_event_for_spam(event)
-                if spam_check_result != self.spam_checker.NOT_SPAM:
+                spam_check_result = (
+                    await self._spam_checker_module_callbacks.check_event_for_spam(
+                        event
+                    )
+                )
+                if spam_check_result != self._spam_checker_module_callbacks.NOT_SPAM:
                     if isinstance(spam_check_result, tuple):
                         try:
                             [code, dict] = spam_check_result
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e4e506e62c..3ac4adcd99 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -110,7 +110,7 @@ class RegistrationHandler:
         self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
         self._server_name = hs.hostname
 
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
 
         if hs.config.worker.worker_app:
             self._register_client = ReplicationRegisterServlet.make_client(hs)
@@ -259,7 +259,7 @@ class RegistrationHandler:
 
         await self.check_registration_ratelimit(address)
 
-        result = await self.spam_checker.check_registration_for_spam(
+        result = await self._spam_checker_module_callbacks.check_registration_for_spam(
             threepid,
             localpart,
             user_agent_ips or [],
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a26ec02284..131d35155f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -105,7 +105,7 @@ class RoomCreationHandler:
         self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.hs = hs
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -445,7 +445,9 @@ class RoomCreationHandler:
         """
         user_id = requester.user.to_string()
 
-        spam_check = await self.spam_checker.user_may_create_room(user_id)
+        spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
+            user_id
+        )
         if spam_check != NOT_SPAM:
             raise SynapseError(
                 403,
@@ -756,7 +758,9 @@ class RoomCreationHandler:
                 )
 
         if not is_requester_admin:
-            spam_check = await self.spam_checker.user_may_create_room(user_id)
+            spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
+                user_id
+            )
             if spam_check != NOT_SPAM:
                 raise SynapseError(
                     403,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a965c7ec76..6541e645c3 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -96,7 +96,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
 
         self.clock = hs.get_clock()
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.third_party_event_rules = hs.get_third_party_event_rules()
         self._server_notices_mxid = self.config.servernotices.server_notices_mxid
         self._enable_lookup = hs.config.registration.enable_3pid_lookup
@@ -802,7 +802,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
                     block_invite_result = (Codes.FORBIDDEN, {})
 
-                spam_check = await self.spam_checker.user_may_invite(
+                spam_check = await self._spam_checker_module_callbacks.user_may_invite(
                     requester.user.to_string(), target_id, room_id
                 )
                 if spam_check != NOT_SPAM:
@@ -931,8 +931,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 # a room then they're allowed to join it.
                 and not new_room
             ):
-                spam_check = await self.spam_checker.user_may_join_room(
-                    target.to_string(), room_id, is_invited=inviter is not None
+                spam_check = (
+                    await self._spam_checker_module_callbacks.user_may_join_room(
+                        target.to_string(), room_id, is_invited=inviter is not None
+                    )
                 )
                 if spam_check != NOT_SPAM:
                     raise SynapseError(
@@ -1541,11 +1543,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             )
         else:
             # Check if the spamchecker(s) allow this invite to go through.
-            spam_check = await self.spam_checker.user_may_send_3pid_invite(
-                inviter_userid=requester.user.to_string(),
-                medium=medium,
-                address=address,
-                room_id=room_id,
+            spam_check = (
+                await self._spam_checker_module_callbacks.user_may_send_3pid_invite(
+                    inviter_userid=requester.user.to_string(),
+                    medium=medium,
+                    address=address,
+                    room_id=room_id,
+                )
             )
             if spam_check != NOT_SPAM:
                 raise SynapseError(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 3610b6bf78..e150fcb16a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -63,7 +63,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.is_mine_id = hs.is_mine_id
         self.update_user_directory = hs.config.worker.should_update_user_directory
         self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         # The current position in the current_state_delta stream
         self.pos: Optional[int] = None
 
@@ -101,7 +101,9 @@ class UserDirectoryHandler(StateDeltasHandler):
         # Remove any spammy users from the results.
         non_spammy_users = []
         for user in results["results"]:
-            if not await self.spam_checker.check_username_for_spam(user):
+            if not await self._spam_checker_module_callbacks.check_username_for_spam(
+                user
+            ):
                 non_spammy_users.append(user)
         results["results"] = non_spammy_users
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index db25848744..0702733926 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -36,7 +36,6 @@ from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
-import synapse
 from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util import Clock
@@ -74,7 +73,7 @@ class MediaStorage:
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
-        self.spam_checker = hs.get_spam_checker()
+        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.clock = hs.get_clock()
 
     async def store_file(self, source: IO, file_info: FileInfo) -> str:
@@ -145,10 +144,10 @@ class MediaStorage:
                     f.flush()
                     f.close()
 
-                    spam_check = await self.spam_checker.check_media_file_for_spam(
+                    spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
                         ReadableFileWrapper(self.clock, fname), file_info
                     )
-                    if spam_check != synapse.module_api.NOT_SPAM:
+                    if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
                         logger.info("Blocking media due to spam checker")
                         # Note that we'll delete the stored media, due to the
                         # try/except below. The media also won't be stored in
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index a02c1c6227..1881f3742d 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -791,7 +791,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             return False
 
         # Configure a spam checker that does not filter any users.
-        spam_checker = self.hs.get_spam_checker()
+        spam_checker = self.hs.get_module_api_callbacks().spam_checker
         spam_checker._check_username_for_spam_callbacks = [allow_all]
 
         # The results do not change:
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4dd763096d..0cb3b5a396 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -814,7 +814,9 @@ class RoomsCreateTestCase(RoomBase):
             return False
 
         join_mock = Mock(side_effect=user_may_join_room)
-        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_join_room_callbacks.append(
+            join_mock
+        )
 
         channel = self.make_request(
             "POST",
@@ -840,7 +842,9 @@ class RoomsCreateTestCase(RoomBase):
             return Codes.CONSENT_NOT_GIVEN
 
         join_mock = Mock(side_effect=user_may_join_room_codes)
-        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_join_room_callbacks.append(
+            join_mock
+        )
 
         channel = self.make_request(
             "POST",
@@ -1162,7 +1166,9 @@ class RoomJoinTestCase(RoomBase):
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
         callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
-        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_join_room_callbacks.append(
+            callback_mock
+        )
 
         # Join a first room, without being invited to it.
         self.helper.join(self.room1, self.user2, tok=self.tok2)
@@ -1227,7 +1233,9 @@ class RoomJoinTestCase(RoomBase):
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
         callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
-        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_join_room_callbacks.append(
+            callback_mock
+        )
 
         # Join a first room, without being invited to it.
         self.helper.join(self.room1, self.user2, tok=self.tok2)
@@ -1643,7 +1651,7 @@ class RoomMessagesTestCase(RoomBase):
 
         spam_checker = SpamCheck()
 
-        self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
+        self.hs.get_module_api_callbacks().spam_checker._check_event_for_spam_callbacks.append(
             spam_checker.check_event_for_spam
         )
 
@@ -3381,7 +3389,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
         mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
-        self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
+            mock
+        )
 
         # Send a 3PID invite into the room and check that it succeeded.
         email_to_invite = "teresa@example.com"
@@ -3446,7 +3456,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(synapse.module_api.NOT_SPAM),
             spec=lambda *x: None,
         )
-        self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+        self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
+            mock
+        )
 
         # Send a 3PID invite into the room and check that it succeeded.
         email_to_invite = "teresa@example.com"
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 8ed27179c4..4c58242e95 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -31,10 +31,10 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import Codes
 from synapse.events import EventBase
-from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
 from synapse.module_api import ModuleApi
+from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.rest.media.v1._base import FileInfo
diff --git a/tests/server.py b/tests/server.py
index 5de9722766..dd1a89014e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -72,11 +72,11 @@ from twisted.web.server import Request, Site
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events.presence_router import load_legacy_presence_router
-from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import ContextResourceUsage
+from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine