diff options
Diffstat (limited to 'synapse/rest/media/v1/media_storage.py')
-rw-r--r-- | synapse/rest/media/v1/media_storage.py | 59 |
1 files changed, 56 insertions, 3 deletions
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 89cdd605aa..aba6d689a8 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -16,13 +16,17 @@ import contextlib import logging import os import shutil -from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence + +import attr from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder @@ -58,6 +62,8 @@ class MediaStorage: self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other @@ -127,18 +133,29 @@ class MediaStorage: f.flush() f.close() + spam = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + raise SpamMediaException() + for provider in self.storage_providers: await provider.store_file(path, file_info) finished_called[0] = True yield f, fname, finish - except Exception: + except Exception as e: try: os.remove(fname) except Exception: pass - raise + + raise e from None if not finished_called: raise Exception("Finished callback not called") @@ -302,3 +319,39 @@ class FileResponder(Responder): def __exit__(self, exc_type, exc_val, exc_tb): self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2 ** 14 + + clock = attr.ib(type=Clock) + path = attr.ib(type=str) + + async def write_chunks_to(self, callback: Callable[[bytes], None]): + """Reads the file in chunks and calls the callback with each chunk. + """ + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) |