summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9311.feature1
-rw-r--r--docs/spam_checker.md3
-rw-r--r--synapse/events/spamcheck.py47
-rw-r--r--synapse/rest/media/v1/media_storage.py59
-rw-r--r--synapse/rest/media/v1/upload_resource.py12
-rw-r--r--tests/rest/media/v1/test_media_storage.py94
6 files changed, 210 insertions, 6 deletions
diff --git a/changelog.d/9311.feature b/changelog.d/9311.feature
new file mode 100644
index 0000000000..293f2118e5
--- /dev/null
+++ b/changelog.d/9311.feature
@@ -0,0 +1 @@
+Add hook to spam checker modules that allow checking file uploads and remote downloads.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 5b4f6428e6..47a27bf85c 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -61,6 +61,9 @@ class ExampleSpamChecker:
 
     async def check_registration_for_spam(self, email_threepid, username, request_info):
         return RegistrationBehaviour.ALLOW  # allow all registrations
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info):
+        return False  # allow all media
 ```
 
 ## Configuration
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..8cfc0bb3cb 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,6 +17,8 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
 from synapse.util.async_helpers import maybe_awaitable
@@ -214,3 +216,48 @@ class SpamChecker:
                     return behaviour
 
         return RegistrationBehaviour.ALLOW
+
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> bool:
+        """Checks if a piece of newly uploaded media should be blocked.
+
+        This will be called for local uploads, downloads of remote media, each
+        thumbnail generated for those, and web pages/images used for URL
+        previews.
+
+        Note that care should be taken to not do blocking IO operations in the
+        main thread. For example, to get the contents of a file a module
+        should do::
+
+            async def check_media_file_for_spam(
+                self, file: ReadableFileWrapper, file_info: FileInfo
+            ) -> bool:
+                buffer = BytesIO()
+                await file.write_chunks_to(buffer.write)
+
+                if buffer.getvalue() == b"Hello World":
+                    return True
+
+                return False
+
+
+        Args:
+            file: An object that allows reading the contents of the media.
+            file_info: Metadata about the file.
+
+        Returns:
+            True if the media should be blocked or False if it should be
+            allowed.
+        """
+
+        for spam_checker in self.spam_checkers:
+            # For backwards compatibility, only run if the method exists on the
+            # spam checker
+            checker = getattr(spam_checker, "check_media_file_for_spam", None)
+            if checker:
+                spam = await maybe_awaitable(checker(file_wrapper, file_info))
+                if spam:
+                    return True
+
+        return False
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..aba6d689a8 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
 import logging
 import os
 import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
 
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
+from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
+        self.spam_checker = hs.get_spam_checker()
+        self.clock = hs.get_clock()
 
     async def store_file(self, source: IO, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
@@ -127,18 +133,29 @@ class MediaStorage:
                     f.flush()
                     f.close()
 
+                    spam = await self.spam_checker.check_media_file_for_spam(
+                        ReadableFileWrapper(self.clock, fname), file_info
+                    )
+                    if spam:
+                        logger.info("Blocking media due to spam checker")
+                        # Note that we'll delete the stored media, due to the
+                        # try/except below. The media also won't be stored in
+                        # the DB.
+                        raise SpamMediaException()
+
                     for provider in self.storage_providers:
                         await provider.store_file(path, file_info)
 
                     finished_called[0] = True
 
                 yield f, fname, finish
-        except Exception:
+        except Exception as e:
             try:
                 os.remove(fname)
             except Exception:
                 pass
-            raise
+
+            raise e from None
 
         if not finished_called:
             raise Exception("Finished callback not called")
@@ -302,3 +319,39 @@ class FileResponder(Responder):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+    """The media was blocked by a spam checker, so we simply 404 the request (in
+    the same way as if it was quarantined).
+    """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+    """Wrapper that allows reading a file in chunks, yielding to the reactor,
+    and writing to a callback.
+
+    This is simplified `FileSender` that takes an IO object rather than an
+    `IConsumer`.
+    """
+
+    CHUNK_SIZE = 2 ** 14
+
+    clock = attr.ib(type=Clock)
+    path = attr.ib(type=str)
+
+    async def write_chunks_to(self, callback: Callable[[bytes], None]):
+        """Reads the file in chunks and calls the callback with each chunk.
+        """
+
+        with open(self.path, "rb") as file:
+            while True:
+                chunk = file.read(self.CHUNK_SIZE)
+                if not chunk:
+                    break
+
+                callback(chunk)
+
+                # We yield to the reactor by sleeping for 0 seconds.
+                await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
         #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
-        content_uri = await self.media_repo.create_content(
-            media_type, upload_name, request.content, content_length, requester.user
-        )
+        try:
+            content_uri = await self.media_repo.create_content(
+                media_type, upload_name, request.content, content_length, requester.user
+            )
+        except SpamMediaException:
+            # For uploading of media we want to respond with a 400, instead of
+            # the default 404, as that would just be confusing.
+            raise SynapseError(400, "Bad content")
 
         logger.info("Uploaded content with URI %r", content_uri)
 
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"X-Robots-Tag"),
             [b"noindex, nofollow, noarchive, noimageindex"],
         )
+
+
+class TestSpamChecker:
+    """A spam checker module that rejects all media that includes the bytes
+    `evil`.
+    """
+
+    def __init__(self, config, api):
+        self.config = config
+        self.api = api
+
+    def parse_config(config):
+        return config
+
+    async def check_event_for_spam(self, foo):
+        return False  # allow all events
+
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        return True  # allow all invites
+
+    async def user_may_create_room(self, userid):
+        return True  # allow all room creations
+
+    async def user_may_create_room_alias(self, userid, room_alias):
+        return True  # allow all room aliases
+
+    async def user_may_publish_room(self, userid, room_id):
+        return True  # allow publishing of all rooms
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+    def default_config(self):
+        config = default_config("test")
+
+        config.update(
+            {
+                "spam_checker": [
+                    {
+                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "config": {},
+                    }
+                ]
+            }
+        )
+
+        return config
+
+    def test_upload_innocent(self):
+        """Attempt to upload some innocent data that should be allowed.
+        """
+
+        image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+        self.helper.upload_media(
+            self.upload_resource, image_data, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self):
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        data = b"Some evil data"
+
+        self.helper.upload_media(
+            self.upload_resource, data, tok=self.tok, expect_code=400
+        )