diff --git a/changelog.d/9311.feature b/changelog.d/9311.feature
new file mode 100644
index 0000000000..293f2118e5
--- /dev/null
+++ b/changelog.d/9311.feature
@@ -0,0 +1 @@
+Add hook to spam checker modules that allow checking file uploads and remote downloads.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 5b4f6428e6..47a27bf85c 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -61,6 +61,9 @@ class ExampleSpamChecker:
async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info):
+ return False # allow all media
```
## Configuration
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..8cfc0bb3cb 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,6 +17,8 @@
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable
@@ -214,3 +216,48 @@ class SpamChecker:
return behaviour
return RegistrationBehaviour.ALLOW
+
+ async def check_media_file_for_spam(
+ self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
+ """Checks if a piece of newly uploaded media should be blocked.
+
+ This will be called for local uploads, downloads of remote media, each
+ thumbnail generated for those, and web pages/images used for URL
+ previews.
+
+ Note that care should be taken to not do blocking IO operations in the
+ main thread. For example, to get the contents of a file a module
+ should do::
+
+ async def check_media_file_for_spam(
+ self, file: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
+ buffer = BytesIO()
+ await file.write_chunks_to(buffer.write)
+
+ if buffer.getvalue() == b"Hello World":
+ return True
+
+ return False
+
+
+ Args:
+ file: An object that allows reading the contents of the media.
+ file_info: Metadata about the file.
+
+ Returns:
+ True if the media should be blocked or False if it should be
+ allowed.
+ """
+
+ for spam_checker in self.spam_checkers:
+ # For backwards compatibility, only run if the method exists on the
+ # spam checker
+ checker = getattr(spam_checker, "check_media_file_for_spam", None)
+ if checker:
+ spam = await maybe_awaitable(checker(file_wrapper, file_info))
+ if spam:
+ return True
+
+ return False
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..aba6d689a8 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
@@ -127,18 +133,29 @@ class MediaStorage:
f.flush()
f.close()
+ spam = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ raise SpamMediaException()
+
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
- except Exception:
+ except Exception as e:
try:
os.remove(fname)
except Exception:
pass
- raise
+
+ raise e from None
if not finished_called:
raise Exception("Finished callback not called")
@@ -302,3 +319,39 @@ class FileResponder(Responder):
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+ """The media was blocked by a spam checker, so we simply 404 the request (in
+ the same way as if it was quarantined).
+ """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+ """Wrapper that allows reading a file in chunks, yielding to the reactor,
+ and writing to a callback.
+
+ This is simplified `FileSender` that takes an IO object rather than an
+ `IConsumer`.
+ """
+
+ CHUNK_SIZE = 2 ** 14
+
+ clock = attr.ib(type=Clock)
+ path = attr.ib(type=str)
+
+ async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ """Reads the file in chunks and calls the callback with each chunk.
+ """
+
+ with open(self.path, "rb") as file:
+ while True:
+ chunk = file.read(self.CHUNK_SIZE)
+ if not chunk:
+ break
+
+ callback(chunk)
+
+ # We yield to the reactor by sleeping for 0 seconds.
+ await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
- content_uri = await self.media_repo.create_content(
- media_type, upload_name, request.content, content_length, requester.user
- )
+ try:
+ content_uri = await self.media_repo.create_content(
+ media_type, upload_name, request.content, content_length, requester.user
+ )
+ except SpamMediaException:
+ # For uploading of media we want to respond with a 400, instead of
+ # the default 404, as that would just be confusing.
+ raise SynapseError(400, "Bad content")
logger.info("Uploaded content with URI %r", content_uri)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
+
+
+class TestSpamChecker:
+ """A spam checker module that rejects all media that includes the bytes
+ `evil`.
+ """
+
+ def __init__(self, config, api):
+ self.config = config
+ self.api = api
+
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
+ return False # allow all events
+
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ return True # allow all invites
+
+ async def user_may_create_room(self, userid):
+ return True # allow all room creations
+
+ async def user_may_create_room_alias(self, userid, room_alias):
+ return True # allow all room aliases
+
+ async def user_may_publish_room(self, userid, room_id):
+ return True # allow publishing of all rooms
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "config": {},
+ }
+ ]
+ }
+ )
+
+ return config
+
+ def test_upload_innocent(self):
+ """Attempt to upload some innocent data that should be allowed.
+ """
+
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ self.helper.upload_media(
+ self.upload_resource, image_data, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self):
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ data = b"Some evil data"
+
+ self.helper.upload_media(
+ self.upload_resource, data, tok=self.tok, expect_code=400
+ )
|