summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/homeserver.py4
-rw-r--r--synapse/events/spamcheck.py37
-rw-r--r--synapse/federation/federation_base.py5
-rw-r--r--synapse/handlers/message.py5
-rw-r--r--synapse/server.py5
5 files changed, 33 insertions, 23 deletions
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b22cacf8dc..3f9d9d5f8b 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig
 from .emailconfig import EmailConfig
 from .workers import WorkerConfig
 from .push import PushConfig
+from .spam_checker import SpamCheckerConfig
 
 
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
                        JWTConfig, PasswordConfig, EmailConfig,
-                       WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
+                       WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+                       SpamCheckerConfig,):
     pass
 
 
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 56fa9e556e..7b22b3413a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -13,26 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+class SpamChecker(object):
+    def __init__(self, hs):
+        self.spam_checker = None
 
-def check_event_for_spam(event):
-    """Checks if a given event is considered "spammy" by this server.
+        if hs.config.spam_checker is not None:
+            module, config = hs.config.spam_checker
+            print("cfg %r", config)
+            self.spam_checker = module(config=config)
 
-    If the server considers an event spammy, then it will be rejected if
-    sent by a local user. If it is sent by a user on another server, then
-    users receive a blank event.
+    def check_event_for_spam(self, event):
+        """Checks if a given event is considered "spammy" by this server.
 
-    Args:
-        event (synapse.events.EventBase): the event to be checked
+        If the server considers an event spammy, then it will be rejected if
+        sent by a local user. If it is sent by a user on another server, then
+        users receive a blank event.
 
-    Returns:
-        bool: True if the event is spammy.
-    """
-    if not hasattr(event, "content") or "body" not in event.content:
-        return False
+        Args:
+            event (synapse.events.EventBase): the event to be checked
 
-    # for example:
-    #
-    # if "the third flower is green" in event.content["body"]:
-    #    return True
+        Returns:
+            bool: True if the event is spammy.
+        """
+        if self.spam_checker is None:
+            return False
 
-    return False
+        return self.spam_checker.check_event_for_spam(event)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index babd9ea078..a0f5d40eb3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,7 +16,6 @@ import logging
 
 from synapse.api.errors import SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import spamcheck
 from synapse.events.utils import prune_event
 from synapse.util import unwrapFirstError, logcontext
 from twisted.internet import defer
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 class FederationBase(object):
     def __init__(self, hs):
-        pass
+        self.spam_checker = hs.get_spam_checker()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -144,7 +143,7 @@ class FederationBase(object):
                     )
                     return redacted
 
-                if spamcheck.check_event_for_spam(pdu):
+                if self.spam_checker.check_event_for_spam(pdu):
                     logger.warn(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id, pdu.get_pdu_json()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da18bf23db..37f0a2772a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.events import spamcheck
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
@@ -58,6 +57,8 @@ class MessageHandler(BaseHandler):
 
         self.action_generator = hs.get_action_generator()
 
+        self.spam_checker = hs.get_spam_checker()
+
     @defer.inlineCallbacks
     def purge_history(self, room_id, event_id):
         event = yield self.store.get_event(event_id)
@@ -322,7 +323,7 @@ class MessageHandler(BaseHandler):
             txn_id=txn_id
         )
 
-        if spamcheck.check_event_for_spam(event):
+        if self.spam_checker.check_event_for_spam(event):
             raise SynapseError(
                 403, "Spam is not permitted here", Codes.FORBIDDEN
             )
diff --git a/synapse/server.py b/synapse/server.py
index a38e5179e0..4d44af745e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
 from synapse.appservice.scheduler import ApplicationServiceScheduler
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
+from synapse.events.spamcheck import SpamChecker
 from synapse.federation import initialize_http_replication
 from synapse.federation.send_queue import FederationRemoteSendQueue
 from synapse.federation.transport.client import TransportLayerClient
@@ -139,6 +140,7 @@ class HomeServer(object):
         'read_marker_handler',
         'action_generator',
         'user_directory_handler',
+        'spam_checker',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -309,6 +311,9 @@ class HomeServer(object):
     def build_user_directory_handler(self):
         return UserDirectoyHandler(self)
 
+    def build_spam_checker(self):
+        return SpamChecker(self)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)