diff options
Diffstat (limited to 'synapse/media/media_storage.py')
-rw-r--r-- | synapse/media/media_storage.py | 259 |
1 files changed, 256 insertions, 3 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index b3cd3fd8f4..1be2c9b5f5 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -19,9 +19,12 @@ # # import contextlib +import json import logging import os import shutil +from contextlib import closing +from io import BytesIO from types import TracebackType from typing import ( IO, @@ -30,24 +33,35 @@ from typing import ( AsyncIterator, BinaryIO, Callable, + List, Optional, Sequence, Tuple, Type, + Union, + cast, ) +from uuid import uuid4 import attr +from zope.interface import implementer +from twisted.internet import interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.logging.context import ( + defer_to_thread, + make_deferred_yieldable, + run_in_background, +) from synapse.logging.opentracing import start_active_span, trace, trace_with_opname from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer +from ..types import JsonDict from ._base import FileInfo, Responder from .filepath import MediaFilePaths @@ -57,6 +71,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +CRLF = b"\r\n" + class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -174,7 +190,7 @@ class MediaStorage: and configured storage providers. Args: - file_info + file_info: Metadata about the media file Returns: Returns a Responder if the file was found, otherwise None. @@ -316,7 +332,7 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file: A file like object to be streamed ot the client, + open_file: A file like object to be streamed to the client, is closed when finished streaming. """ @@ -370,3 +386,240 @@ class ReadableFileWrapper: # We yield to the reactor by sleeping for 0 seconds. await self.clock.sleep(0) + + +@implementer(interfaces.IConsumer) +@implementer(interfaces.IPushProducer) +class MultipartFileConsumer: + """Wraps a given consumer so that any data that gets written to it gets + converted to a multipart format. + """ + + def __init__( + self, + clock: Clock, + wrapped_consumer: interfaces.IConsumer, + file_content_type: str, + json_object: JsonDict, + content_length: Optional[int] = None, + ) -> None: + self.clock = clock + self.wrapped_consumer = wrapped_consumer + self.json_field = json_object + self.json_field_written = False + self.content_type_written = False + self.file_content_type = file_content_type + self.boundary = uuid4().hex.encode("ascii") + + # The producer that registered with us, and if it's a push or pull + # producer. + self.producer: Optional["interfaces.IProducer"] = None + self.streaming: Optional[bool] = None + + # Whether the wrapped consumer has asked us to pause. + self.paused = False + + self.length = content_length + + ### IConsumer APIs ### + + def registerProducer( + self, producer: "interfaces.IProducer", streaming: bool + ) -> None: + """ + Register to receive data from a producer. + + This sets self to be a consumer for a producer. When this object runs + out of data (as when a send(2) call on a socket succeeds in moving the + last data from a userspace buffer into a kernelspace buffer), it will + ask the producer to resumeProducing(). + + For L{IPullProducer} providers, C{resumeProducing} will be called once + each time data is required. + + For L{IPushProducer} providers, C{pauseProducing} will be called + whenever the write buffer fills up and C{resumeProducing} will only be + called when it empties. The consumer will only call C{resumeProducing} + to balance a previous C{pauseProducing} call; the producer is assumed + to start in an un-paused state. + + @param streaming: C{True} if C{producer} provides L{IPushProducer}, + C{False} if C{producer} provides L{IPullProducer}. + + @raise RuntimeError: If a producer is already registered. + """ + self.producer = producer + self.streaming = streaming + + self.wrapped_consumer.registerProducer(self, True) + + # kick off producing if `self.producer` is not a streaming producer + if not streaming: + self.resumeProducing() + + def unregisterProducer(self) -> None: + """ + Stop consuming data from a producer, without disconnecting. + """ + self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF) + self.wrapped_consumer.unregisterProducer() + self.paused = True + + def write(self, data: bytes) -> None: + """ + The producer will write data by calling this method. + + The implementation must be non-blocking and perform whatever + buffering is necessary. If the producer has provided enough data + for now and it is a L{IPushProducer}, the consumer may call its + C{pauseProducing} method. + """ + if not self.json_field_written: + self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF) + + content_type = Header(b"Content-Type", b"application/json") + self.wrapped_consumer.write(bytes(content_type) + CRLF) + + json_field = json.dumps(self.json_field) + json_bytes = json_field.encode("utf-8") + self.wrapped_consumer.write(CRLF + json_bytes) + self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF) + + self.json_field_written = True + + # if we haven't written the content type yet, do so + if not self.content_type_written: + type = self.file_content_type.encode("utf-8") + content_type = Header(b"Content-Type", type) + self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF) + self.content_type_written = True + + self.wrapped_consumer.write(data) + + ### IPushProducer APIs ### + + def stopProducing(self) -> None: + """ + Stop producing data. + + This tells a producer that its consumer has died, so it must stop + producing data for good. + """ + assert self.producer is not None + + self.paused = True + self.producer.stopProducing() + + def pauseProducing(self) -> None: + """ + Pause producing data. + + Tells a producer that it has produced too much data to process for + the time being, and to stop until C{resumeProducing()} is called. + """ + assert self.producer is not None + + self.paused = True + + if self.streaming: + cast("interfaces.IPushProducer", self.producer).pauseProducing() + else: + self.paused = True + + def resumeProducing(self) -> None: + """ + Resume producing data. + + This tells a producer to re-add itself to the main loop and produce + more data for its consumer. + """ + assert self.producer is not None + + if self.streaming: + cast("interfaces.IPushProducer", self.producer).resumeProducing() + else: + # If the producer is not a streaming producer we need to start + # repeatedly calling `resumeProducing` in a loop. + run_in_background(self._resumeProducingRepeatedly) + + def content_length(self) -> Optional[int]: + """ + Calculate the content length of the multipart response + in bytes. + """ + if not self.length: + return None + # calculate length of json field and content-type header + json_field = json.dumps(self.json_field) + json_bytes = json_field.encode("utf-8") + json_length = len(json_bytes) + + type = self.file_content_type.encode("utf-8") + content_type = Header(b"Content-Type", type) + type_length = len(bytes(content_type)) + + # 154 is the length of the elements that aren't variable, ie + # CRLFs and boundary strings, etc + self.length += json_length + type_length + 154 + + return self.length + + ### Internal APIs. ### + + async def _resumeProducingRepeatedly(self) -> None: + assert self.producer is not None + assert not self.streaming + + producer = cast("interfaces.IPullProducer", self.producer) + + self.paused = False + while not self.paused: + producer.resumeProducing() + await self.clock.sleep(0) + + +class Header: + """ + `Header` This class is a tiny wrapper that produces + request headers. We can't use standard python header + class because it encodes unicode fields using =? bla bla ?= + encoding, which is correct, but no one in HTTP world expects + that, everyone wants utf-8 raw bytes. (stolen from treq.multipart) + + """ + + def __init__( + self, + name: bytes, + value: Any, + params: Optional[List[Tuple[Any, Any]]] = None, + ): + self.name = name + self.value = value + self.params = params or [] + + def add_param(self, name: Any, value: Any) -> None: + self.params.append((name, value)) + + def __bytes__(self) -> bytes: + with closing(BytesIO()) as h: + h.write(self.name + b": " + escape(self.value).encode("us-ascii")) + if self.params: + for name, val in self.params: + h.write(b"; ") + h.write(escape(name).encode("us-ascii")) + h.write(b"=") + h.write(b'"' + escape(val).encode("utf-8") + b'"') + h.seek(0) + return h.read() + + +def escape(value: Union[str, bytes]) -> str: + """ + This function prevents header values from corrupting the request, + a newline in the file name parameter makes form-data request unreadable + for a majority of parsers. (stolen from treq.multipart) + """ + if isinstance(value, bytes): + value = value.decode("utf-8") + return value.replace("\r", "").replace("\n", "").replace('"', '\\"') |