diff options
Diffstat (limited to 'synapse/media/media_storage.py')
-rw-r--r-- | synapse/media/media_storage.py | 223 |
1 files changed, 215 insertions, 8 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index b3cd3fd8f4..2f55d12b6b 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -19,9 +19,12 @@ # # import contextlib +import json import logging import os import shutil +from contextlib import closing +from io import BytesIO from types import TracebackType from typing import ( IO, @@ -30,14 +33,19 @@ from typing import ( AsyncIterator, BinaryIO, Callable, + List, Optional, Sequence, Tuple, Type, + Union, ) +from uuid import uuid4 import attr +from zope.interface import implementer +from twisted.internet import defer, interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender @@ -48,15 +56,19 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer +from ..storage.databases.main.media_repository import LocalMedia +from ..types import JsonDict from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: - from synapse.media.storage_provider import StorageProvider + from synapse.media.storage_provider import StorageProviderWrapper from synapse.server import HomeServer logger = logging.getLogger(__name__) +CRLF = b"\r\n" + class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -73,7 +85,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProvider"], + storage_providers: Sequence["StorageProviderWrapper"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -169,15 +181,23 @@ class MediaStorage: raise e from None - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + async def fetch_media( + self, + file_info: FileInfo, + media_info: Optional[LocalMedia] = None, + federation: bool = False, + ) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info + file_info: Metadata about the media file + media_info: Metadata about the media item + federation: Whether this file is being fetched for a federation request Returns: - Returns a Responder if the file was found, otherwise None. + If the file was found returns a Responder (a Multipart Responder if the requested + file is for the federation /download endpoint), otherwise None. """ paths = [self._file_info_to_path(file_info)] @@ -197,12 +217,19 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) + if federation: + assert media_info is not None + boundary = uuid4().hex.encode("ascii") + return MultipartResponder( + open(local_path, "rb"), media_info, boundary + ) + else: + return FileResponder(open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: for path in paths: - res: Any = await provider.fetch(path, file_info) + res: Any = await provider.fetch(path, file_info, media_info, federation) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -316,7 +343,7 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file: A file like object to be streamed ot the client, + open_file: A file like object to be streamed to the client, is closed when finished streaming. """ @@ -337,6 +364,38 @@ class FileResponder(Responder): self.open_file.close() +class MultipartResponder(Responder): + """Wraps an open file, formats the response according to MSC3916 and sends it to a + federation request. + + Args: + open_file: A file like object to be streamed to the client, + is closed when finished streaming. + media_info: metadata about the media item + boundary: bytes to use for the multipart response boundary + """ + + def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None: + self.open_file = open_file + self.media_info = media_info + self.boundary = boundary + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + MultipartFileSender().beginFileTransfer( + self.open_file, consumer, self.media_info.media_type, {}, self.boundary + ) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + class SpamMediaException(NotFoundError): """The media was blocked by a spam checker, so we simply 404 the request (in the same way as if it was quarantined). @@ -370,3 +429,151 @@ class ReadableFileWrapper: # We yield to the reactor by sleeping for 0 seconds. await self.clock.sleep(0) + + +@implementer(interfaces.IProducer) +class MultipartFileSender: + """ + A producer that sends the contents of a file to a federation request in the format + outlined in MSC3916 - a multipart/format-data response where the first field is a + JSON object and the second is the requested file. + + This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format + outlined above. + """ + + CHUNK_SIZE = 2**14 + + lastSent = "" + deferred: Optional[defer.Deferred] = None + + def beginFileTransfer( + self, + file: IO, + consumer: IConsumer, + file_content_type: str, + json_object: JsonDict, + boundary: bytes, + ) -> Deferred: + """ + Begin transferring a file + + Args: + file: The file object to read data from + consumer: The synapse request to write the data to + file_content_type: The content-type of the file + json_object: The JSON object to write to the first field of the response + boundary: bytes to be used as the multipart/form-data boundary + + Returns: A deferred whose callback will be invoked when the file has + been completely written to the consumer. The last byte written to the + consumer is passed to the callback. + """ + self.file: Optional[IO] = file + self.consumer = consumer + self.json_field = json_object + self.json_field_written = False + self.content_type_written = False + self.file_content_type = file_content_type + self.boundary = boundary + self.deferred: Deferred = defer.Deferred() + self.consumer.registerProducer(self, False) + # while it's not entirely clear why this assignment is necessary, it mirrors + # the behavior in FileSender.beginFileTransfer and thus is preserved here + deferred = self.deferred + return deferred + + def resumeProducing(self) -> None: + # write the first field, which will always be a json field + if not self.json_field_written: + self.consumer.write(CRLF + b"--" + self.boundary + CRLF) + + content_type = Header(b"Content-Type", b"application/json") + self.consumer.write(bytes(content_type) + CRLF) + + json_field = json.dumps(self.json_field) + json_bytes = json_field.encode("utf-8") + self.consumer.write(json_bytes) + self.consumer.write(CRLF + b"--" + self.boundary + CRLF) + + self.json_field_written = True + + chunk: Any = "" + if self.file: + # if we haven't written the content type yet, do so + if not self.content_type_written: + type = self.file_content_type.encode("utf-8") + content_type = Header(b"Content-Type", type) + self.consumer.write(bytes(content_type) + CRLF) + self.content_type_written = True + + chunk = self.file.read(self.CHUNK_SIZE) + + if not chunk: + # we've reached the end of the file + self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF) + self.file = None + self.consumer.unregisterProducer() + + if self.deferred: + self.deferred.callback(self.lastSent) + self.deferred = None + return + + self.consumer.write(chunk) + self.lastSent = chunk[-1:] + + def pauseProducing(self) -> None: + pass + + def stopProducing(self) -> None: + if self.deferred: + self.deferred.errback(Exception("Consumer asked us to stop producing")) + self.deferred = None + + +class Header: + """ + `Header` This class is a tiny wrapper that produces + request headers. We can't use standard python header + class because it encodes unicode fields using =? bla bla ?= + encoding, which is correct, but no one in HTTP world expects + that, everyone wants utf-8 raw bytes. (stolen from treq.multipart) + + """ + + def __init__( + self, + name: bytes, + value: Any, + params: Optional[List[Tuple[Any, Any]]] = None, + ): + self.name = name + self.value = value + self.params = params or [] + + def add_param(self, name: Any, value: Any) -> None: + self.params.append((name, value)) + + def __bytes__(self) -> bytes: + with closing(BytesIO()) as h: + h.write(self.name + b": " + escape(self.value).encode("us-ascii")) + if self.params: + for name, val in self.params: + h.write(b"; ") + h.write(escape(name).encode("us-ascii")) + h.write(b"=") + h.write(b'"' + escape(val).encode("utf-8") + b'"') + h.seek(0) + return h.read() + + +def escape(value: Union[str, bytes]) -> str: + """ + This function prevents header values from corrupting the request, + a newline in the file name parameter makes form-data request unreadable + for a majority of parsers. (stolen from treq.multipart) + """ + if isinstance(value, bytes): + value = value.decode("utf-8") + return value.replace("\r", "").replace("\n", "").replace('"', '\\"') |