summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/events/builder.py4
-rw-r--r--synapse/events/snapshot.py3
-rw-r--r--synapse/events/spamcheck.py47
-rw-r--r--synapse/events/third_party_rules.py3
-rw-r--r--synapse/events/utils.py2
5 files changed, 54 insertions, 5 deletions
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 07df258e6e..c1c0426f6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -98,7 +98,9 @@ class EventBuilder:
         return self._state_key is not None
 
     async def build(
-        self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
+        self,
+        prev_event_ids: List[str],
+        auth_event_ids: Optional[List[str]],
     ) -> EventBase:
         """Transform into a fully signed and hashed event
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index afecafe15c..7295df74fe 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
 
 
 def _decode_state_dict(input):
-    """Decodes a state dict encoded using `_encode_state_dict` above
-    """
+    """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:
         return None
 
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..8cfc0bb3cb 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,6 +17,8 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
 from synapse.util.async_helpers import maybe_awaitable
@@ -214,3 +216,48 @@ class SpamChecker:
                     return behaviour
 
         return RegistrationBehaviour.ALLOW
+
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> bool:
+        """Checks if a piece of newly uploaded media should be blocked.
+
+        This will be called for local uploads, downloads of remote media, each
+        thumbnail generated for those, and web pages/images used for URL
+        previews.
+
+        Note that care should be taken to not do blocking IO operations in the
+        main thread. For example, to get the contents of a file a module
+        should do::
+
+            async def check_media_file_for_spam(
+                self, file: ReadableFileWrapper, file_info: FileInfo
+            ) -> bool:
+                buffer = BytesIO()
+                await file.write_chunks_to(buffer.write)
+
+                if buffer.getvalue() == b"Hello World":
+                    return True
+
+                return False
+
+
+        Args:
+            file: An object that allows reading the contents of the media.
+            file_info: Metadata about the file.
+
+        Returns:
+            True if the media should be blocked or False if it should be
+            allowed.
+        """
+
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_media_file_for_spam", None)
+            if checker:
+                spam = await maybe_awaitable(checker(file_wrapper, file_info))
+                if spam:
+                    return True
+
+        return False
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 77fbd3f68a..02bce8b5c9 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -40,7 +40,8 @@ class ThirdPartyEventRules:
 
         if module is not None:
             self.third_party_rules = module(
-                config=config, module_api=hs.get_module_api(),
+                config=config,
+                module_api=hs.get_module_api(),
             )
 
     async def check_event_allowed(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9c22e33813..7ca5c9940a 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -34,7 +34,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
 
 
 def prune_event(event: EventBase) -> EventBase:
-    """ Returns a pruned version of the given event, which removes all keys we
+    """Returns a pruned version of the given event, which removes all keys we
     don't know about or think could potentially be dodgy.
 
     This is used when we "redact" an event. We want to remove all fields that