diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
new file mode 100644
index 0000000000..a7e22a91e1
--- /dev/null
+++ b/synapse/media/media_storage.py
@@ -0,0 +1,374 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import logging
+import os
+import shutil
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
+
+import attr
+
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
+from twisted.protocols.basic import FileSender
+
+import synapse
+from synapse.api.errors import NotFoundError
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.media.storage_provider import StorageProvider
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage:
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProvider that are used to fetch and store files.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProvider"],
+ ):
+ self.hs = hs
+ self.reactor = hs.get_reactor()
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
+
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info: Info about the file to store
+
+ Returns:
+ the file path written to in the primary media store
+ """
+
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
+ # Write to the main repository
+ await self.write_to_file(source, f)
+ await finish_cb()
+
+ return fname
+
+ async def write_to_file(self, source: IO, output: IO) -> None:
+ """Asynchronously write the `source` to `output`."""
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
+ @contextlib.contextmanager
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns an awaitable.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+ finish_cb must be called and waited on after the file has been
+ successfully been written to. Should not be called if there was an
+ error.
+
+ Args:
+ file_info: Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ await finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ os.makedirs(dirname, exist_ok=True)
+
+ finished_called = [False]
+
+ try:
+ with open(fname, "wb") as f:
+
+ async def finish() -> None:
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ spam_check = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam_check != synapse.module_api.NOT_SPAM:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ # We currently ignore any additional field returned by
+ # the spam-check API.
+ raise SpamMediaException(errcode=spam_check[0])
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ yield f, fname, finish
+ except Exception as e:
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+
+ raise e from None
+
+ if not finished_called:
+ raise Exception("Finished callback not called")
+
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info
+
+ Returns:
+ Returns a Responder if the file was found, otherwise None.
+ """
+ paths = [self._file_info_to_path(file_info)]
+
+ # fallback for remote thumbnails with no method in the filename
+ if file_info.thumbnail and file_info.server_name:
+ paths.append(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ )
+ )
+
+ for path in paths:
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ logger.debug("responding with local file %s", local_path)
+ return FileResponder(open(local_path, "rb"))
+ logger.debug("local file %s did not exist", local_path)
+
+ for provider in self.storage_providers:
+ for path in paths:
+ res: Any = await provider.fetch(path, file_info)
+ if res:
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
+ logger.debug("%s not found on %s", path, provider)
+
+ return None
+
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info
+
+ Returns:
+ Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ return local_path
+
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
+ dirname = os.path.dirname(local_path)
+ os.makedirs(dirname, exist_ok=True)
+
+ for provider in self.storage_providers:
+ res: Any = await provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(
+ open(local_path, "wb"), self.reactor
+ )
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
+ return local_path
+
+ raise NotFoundError()
+
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
+ )
+ return self.filepaths.local_media_filepath_rel(file_info.file_id)
+
+
+def _write_file_synchronously(source: IO, dest: IO) -> None:
+ """Write `source` to the file like `dest` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object that's to be written
+ dest: A file like object to be written to
+ """
+ source.seek(0) # Ensure we read from the start of the file
+ shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+ open_file: A file like object to be streamed ot the client,
+ is closed when finished streaming.
+ """
+
+ def __init__(self, open_file: IO):
+ self.open_file = open_file
+
+ def write_to_consumer(self, consumer: IConsumer) -> Deferred:
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
+ self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+ """The media was blocked by a spam checker, so we simply 404 the request (in
+ the same way as if it was quarantined).
+ """
+
+
+@attr.s(slots=True, auto_attribs=True)
+class ReadableFileWrapper:
+ """Wrapper that allows reading a file in chunks, yielding to the reactor,
+ and writing to a callback.
+
+ This is simplified `FileSender` that takes an IO object rather than an
+ `IConsumer`.
+ """
+
+ CHUNK_SIZE = 2**14
+
+ clock: Clock
+ path: str
+
+ async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
+ """Reads the file in chunks and calls the callback with each chunk."""
+
+ with open(self.path, "rb") as file:
+ while True:
+ chunk = file.read(self.CHUNK_SIZE)
+ if not chunk:
+ break
+
+ callback(chunk)
+
+ # We yield to the reactor by sleeping for 0 seconds.
+ await self.clock.sleep(0)
|