diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..aba6d689a8 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
@@ -127,18 +133,29 @@ class MediaStorage:
f.flush()
f.close()
+ spam = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ raise SpamMediaException()
+
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
- except Exception:
+ except Exception as e:
try:
os.remove(fname)
except Exception:
pass
- raise
+
+ raise e from None
if not finished_called:
raise Exception("Finished callback not called")
@@ -302,3 +319,39 @@ class FileResponder(Responder):
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+    """Raised when a spam checker rejects a media file. Being a `NotFoundError`,
+    it surfaces as a 404 by default, the same as quarantined media.
+    """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+    """Streams the file at `path` to a callback, one chunk at a time.
+
+    A pared-down analogue of `FileSender`: chunks go to a plain callable
+    rather than an `IConsumer`, with a reactor yield after every chunk.
+    """
+
+    CHUNK_SIZE = 2 ** 14
+
+    clock = attr.ib(type=Clock)
+    path = attr.ib(type=str)
+
+    async def write_chunks_to(self, callback: Callable[[bytes], None]):
+        """Feed the file's bytes to `callback` in `CHUNK_SIZE` pieces.
+
+        `callback` is invoked once per non-empty chunk, in file order.
+        """
+
+        with open(self.path, "rb") as source:
+            chunk = source.read(self.CHUNK_SIZE)
+            while chunk:
+                callback(chunk)
+
+                # A zero-second sleep hands control back to the reactor.
+                await self.clock.sleep(0)
+
+                chunk = source.read(self.CHUNK_SIZE)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
- content_uri = await self.media_repo.create_content(
- media_type, upload_name, request.content, content_length, requester.user
- )
+ try:
+ content_uri = await self.media_repo.create_content(
+ media_type, upload_name, request.content, content_length, requester.user
+ )
+ except SpamMediaException:
+ # For uploading of media we want to respond with a 400, instead of
+ # the default 404, as that would just be confusing.
+ raise SynapseError(400, "Bad content")
logger.info("Uploaded content with URI %r", content_uri)
|