diff options
Diffstat (limited to 'synapse/media/media_storage.py')
-rw-r--r-- | synapse/media/media_storage.py | 223 |
1 files changed, 8 insertions, 215 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index 2f55d12b6b..b3cd3fd8f4 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -19,12 +19,9 @@ # # import contextlib -import json import logging import os import shutil -from contextlib import closing -from io import BytesIO from types import TracebackType from typing import ( IO, @@ -33,19 +30,14 @@ from typing import ( AsyncIterator, BinaryIO, Callable, - List, Optional, Sequence, Tuple, Type, - Union, ) -from uuid import uuid4 import attr -from zope.interface import implementer -from twisted.internet import defer, interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender @@ -56,19 +48,15 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer -from ..storage.databases.main.media_repository import LocalMedia -from ..types import JsonDict from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: - from synapse.media.storage_provider import StorageProviderWrapper + from synapse.media.storage_provider import StorageProvider from synapse.server import HomeServer logger = logging.getLogger(__name__) -CRLF = b"\r\n" - class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -85,7 +73,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], + storage_providers: Sequence["StorageProvider"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -181,23 +169,15 @@ class MediaStorage: raise e from None - async def fetch_media( - self, - file_info: FileInfo, - media_info: Optional[LocalMedia] = None, - federation: bool = False, - ) -> Optional[Responder]: + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info: Metadata about the media file - media_info: Metadata about the media item - federation: Whether this file is being fetched for a federation request + file_info Returns: - If the file was found returns a Responder (a Multipart Responder if the requested - file is for the federation /download endpoint), otherwise None. + Returns a Responder if the file was found, otherwise None. """ paths = [self._file_info_to_path(file_info)] @@ -217,19 +197,12 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - if federation: - assert media_info is not None - boundary = uuid4().hex.encode("ascii") - return MultipartResponder( - open(local_path, "rb"), media_info, boundary - ) - else: - return FileResponder(open(local_path, "rb")) + return FileResponder(open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: for path in paths: - res: Any = await provider.fetch(path, file_info, media_info, federation) + res: Any = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -343,7 +316,7 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file: A file like object to be streamed to the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ @@ -364,38 +337,6 @@ class FileResponder(Responder): self.open_file.close() -class MultipartResponder(Responder): - """Wraps an open file, formats the response according to MSC3916 and sends it to a - federation request. - - Args: - open_file: A file like object to be streamed to the client, - is closed when finished streaming. - media_info: metadata about the media item - boundary: bytes to use for the multipart response boundary - """ - - def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None: - self.open_file = open_file - self.media_info = media_info - self.boundary = boundary - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - MultipartFileSender().beginFileTransfer( - self.open_file, consumer, self.media_info.media_type, {}, self.boundary - ) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - class SpamMediaException(NotFoundError): """The media was blocked by a spam checker, so we simply 404 the request (in the same way as if it was quarantined). @@ -429,151 +370,3 @@ class ReadableFileWrapper: # We yield to the reactor by sleeping for 0 seconds. await self.clock.sleep(0) - - -@implementer(interfaces.IProducer) -class MultipartFileSender: - """ - A producer that sends the contents of a file to a federation request in the format - outlined in MSC3916 - a multipart/format-data response where the first field is a - JSON object and the second is the requested file. - - This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format - outlined above. - """ - - CHUNK_SIZE = 2**14 - - lastSent = "" - deferred: Optional[defer.Deferred] = None - - def beginFileTransfer( - self, - file: IO, - consumer: IConsumer, - file_content_type: str, - json_object: JsonDict, - boundary: bytes, - ) -> Deferred: - """ - Begin transferring a file - - Args: - file: The file object to read data from - consumer: The synapse request to write the data to - file_content_type: The content-type of the file - json_object: The JSON object to write to the first field of the response - boundary: bytes to be used as the multipart/form-data boundary - - Returns: A deferred whose callback will be invoked when the file has - been completely written to the consumer. The last byte written to the - consumer is passed to the callback. - """ - self.file: Optional[IO] = file - self.consumer = consumer - self.json_field = json_object - self.json_field_written = False - self.content_type_written = False - self.file_content_type = file_content_type - self.boundary = boundary - self.deferred: Deferred = defer.Deferred() - self.consumer.registerProducer(self, False) - # while it's not entirely clear why this assignment is necessary, it mirrors - # the behavior in FileSender.beginFileTransfer and thus is preserved here - deferred = self.deferred - return deferred - - def resumeProducing(self) -> None: - # write the first field, which will always be a json field - if not self.json_field_written: - self.consumer.write(CRLF + b"--" + self.boundary + CRLF) - - content_type = Header(b"Content-Type", b"application/json") - self.consumer.write(bytes(content_type) + CRLF) - - json_field = json.dumps(self.json_field) - json_bytes = json_field.encode("utf-8") - self.consumer.write(json_bytes) - self.consumer.write(CRLF + b"--" + self.boundary + CRLF) - - self.json_field_written = True - - chunk: Any = "" - if self.file: - # if we haven't written the content type yet, do so - if not self.content_type_written: - type = self.file_content_type.encode("utf-8") - content_type = Header(b"Content-Type", type) - self.consumer.write(bytes(content_type) + CRLF) - self.content_type_written = True - - chunk = self.file.read(self.CHUNK_SIZE) - - if not chunk: - # we've reached the end of the file - self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF) - self.file = None - self.consumer.unregisterProducer() - - if self.deferred: - self.deferred.callback(self.lastSent) - self.deferred = None - return - - self.consumer.write(chunk) - self.lastSent = chunk[-1:] - - def pauseProducing(self) -> None: - pass - - def stopProducing(self) -> None: - if self.deferred: - self.deferred.errback(Exception("Consumer asked us to stop producing")) - self.deferred = None - - -class Header: - """ - `Header` This class is a tiny wrapper that produces - request headers. We can't use standard python header - class because it encodes unicode fields using =? bla bla ?= - encoding, which is correct, but no one in HTTP world expects - that, everyone wants utf-8 raw bytes. (stolen from treq.multipart) - - """ - - def __init__( - self, - name: bytes, - value: Any, - params: Optional[List[Tuple[Any, Any]]] = None, - ): - self.name = name - self.value = value - self.params = params or [] - - def add_param(self, name: Any, value: Any) -> None: - self.params.append((name, value)) - - def __bytes__(self) -> bytes: - with closing(BytesIO()) as h: - h.write(self.name + b": " + escape(self.value).encode("us-ascii")) - if self.params: - for name, val in self.params: - h.write(b"; ") - h.write(escape(name).encode("us-ascii")) - h.write(b"=") - h.write(b'"' + escape(val).encode("utf-8") + b'"') - h.seek(0) - return h.read() - - -def escape(value: Union[str, bytes]) -> str: - """ - This function prevents header values from corrupting the request, - a newline in the file name parameter makes form-data request unreadable - for a majority of parsers. (stolen from treq.multipart) - """ - if isinstance(value, bytes): - value = value.decode("utf-8") - return value.replace("\r", "").replace("\n", "").replace('"', '\\"') |