diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b3cd3fd8f4..1be2c9b5f5 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,9 +19,12 @@
#
#
import contextlib
+import json
import logging
import os
import shutil
+from contextlib import closing
+from io import BytesIO
from types import TracebackType
from typing import (
IO,
@@ -30,24 +33,35 @@ from typing import (
AsyncIterator,
BinaryIO,
Callable,
+ List,
Optional,
Sequence,
Tuple,
Type,
+ Union,
+ cast,
)
+from uuid import uuid4
import attr
+from zope.interface import implementer
+from twisted.internet import interfaces
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import NotFoundError
-from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.logging.context import (
+ defer_to_thread,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
+from ..types import JsonDict
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
@@ -57,6 +71,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+CRLF = b"\r\n"
+
class MediaStorage:
"""Responsible for storing/fetching files from local sources.
@@ -174,7 +190,7 @@ class MediaStorage:
and configured storage providers.
Args:
- file_info
+ file_info: Metadata about the media file
Returns:
Returns a Responder if the file was found, otherwise None.
@@ -316,7 +332,7 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
- open_file: A file like object to be streamed ot the client,
+ open_file: A file like object to be streamed to the client,
is closed when finished streaming.
"""
@@ -370,3 +386,240 @@ class ReadableFileWrapper:
# We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0)
+
+
+@implementer(interfaces.IConsumer)
+@implementer(interfaces.IPushProducer)
+class MultipartFileConsumer:
+ """Wraps a given consumer so that any data that gets written to it gets
+ converted to a multipart format.
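+
+    The wrapped data is written out as a multipart body with two parts: a
+    json part followed by the file itself. Roughly, with a shortened
+    boundary:
+
+        \r\n--abc123\r\n
+        Content-Type: application/json\r\n
+        \r\n
+        {json}\r\n
+        --abc123\r\n
+        Content-Type: {file content type}\r\n
+        \r\n
+        {file bytes}\r\n
+        --abc123--\r\n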
+ """
+
+ def __init__(
+ self,
+ clock: Clock,
+ wrapped_consumer: interfaces.IConsumer,
+ file_content_type: str,
+ json_object: JsonDict,
+ content_length: Optional[int] = None,
+ ) -> None:
+ self.clock = clock
+ self.wrapped_consumer = wrapped_consumer
+ self.json_field = json_object
+ self.json_field_written = False
+ self.content_type_written = False
+ self.file_content_type = file_content_type
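+        # A random, ASCII-safe part boundary: a uuid4 hex string is
+        # vanishingly unlikely to appear in the payload.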
+ self.boundary = uuid4().hex.encode("ascii")
+
+ # The producer that registered with us, and if it's a push or pull
+ # producer.
+ self.producer: Optional["interfaces.IProducer"] = None
+ self.streaming: Optional[bool] = None
+
+ # Whether the wrapped consumer has asked us to pause.
+ self.paused = False
+
+ self.length = content_length
+
+ ### IConsumer APIs ###
+
+ def registerProducer(
+ self, producer: "interfaces.IProducer", streaming: bool
+ ) -> None:
+ """
+ Register to receive data from a producer.
+
+ This sets self to be a consumer for a producer. When this object runs
+ out of data (as when a send(2) call on a socket succeeds in moving the
+ last data from a userspace buffer into a kernelspace buffer), it will
+ ask the producer to resumeProducing().
+
+ For L{IPullProducer} providers, C{resumeProducing} will be called once
+ each time data is required.
+
+ For L{IPushProducer} providers, C{pauseProducing} will be called
+ whenever the write buffer fills up and C{resumeProducing} will only be
+ called when it empties. The consumer will only call C{resumeProducing}
+ to balance a previous C{pauseProducing} call; the producer is assumed
+ to start in an un-paused state.
+
+ @param streaming: C{True} if C{producer} provides L{IPushProducer},
+ C{False} if C{producer} provides L{IPullProducer}.
+
+ @raise RuntimeError: If a producer is already registered.
+ """
+ self.producer = producer
+ self.streaming = streaming
+
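+        # Register ourselves with the wrapped consumer as a push producer
+        # in both cases: a pull producer is instead driven internally via
+        # `_resumeProducingRepeatedly`.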
+ self.wrapped_consumer.registerProducer(self, True)
+
+ # kick off producing if `self.producer` is not a streaming producer
+ if not streaming:
+ self.resumeProducing()
+
+ def unregisterProducer(self) -> None:
+ """
+ Stop consuming data from a producer, without disconnecting.
+ """
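+        # Emit the closing boundary line to terminate the multipart body.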
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
+ self.wrapped_consumer.unregisterProducer()
+ self.paused = True
+
+ def write(self, data: bytes) -> None:
+ """
+ The producer will write data by calling this method.
+
+ The implementation must be non-blocking and perform whatever
+ buffering is necessary. If the producer has provided enough data
+ for now and it is a L{IPushProducer}, the consumer may call its
+ C{pauseProducing} method.
+ """
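+        # On the first write, emit the json part in full (opening boundary,
+        # Content-Type header and json body), followed by the boundary line
+        # that opens the file part.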
+ if not self.json_field_written:
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+ content_type = Header(b"Content-Type", b"application/json")
+ self.wrapped_consumer.write(bytes(content_type) + CRLF)
+
+ json_field = json.dumps(self.json_field)
+ json_bytes = json_field.encode("utf-8")
+ self.wrapped_consumer.write(CRLF + json_bytes)
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+ self.json_field_written = True
+
+ # if we haven't written the content type yet, do so
+ if not self.content_type_written:
+            content_type_bytes = self.file_content_type.encode("utf-8")
+            content_type = Header(b"Content-Type", content_type_bytes)
+ self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
+ self.content_type_written = True
+
+ self.wrapped_consumer.write(data)
+
+ ### IPushProducer APIs ###
+
+ def stopProducing(self) -> None:
+ """
+ Stop producing data.
+
+ This tells a producer that its consumer has died, so it must stop
+ producing data for good.
+ """
+ assert self.producer is not None
+
+ self.paused = True
+ self.producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ """
+ Pause producing data.
+
+ Tells a producer that it has produced too much data to process for
+ the time being, and to stop until C{resumeProducing()} is called.
+ """
+ assert self.producer is not None
+
+ self.paused = True
+
+        # Only push producers expose pauseProducing; for a pull producer
+        # the `_resumeProducingRepeatedly` loop checks `self.paused` and
+        # stops on its own.
+        if self.streaming:
+            cast("interfaces.IPushProducer", self.producer).pauseProducing()
+
+ def resumeProducing(self) -> None:
+ """
+ Resume producing data.
+
+ This tells a producer to re-add itself to the main loop and produce
+ more data for its consumer.
+ """
+ assert self.producer is not None
+
+ if self.streaming:
+ cast("interfaces.IPushProducer", self.producer).resumeProducing()
+ else:
+ # If the producer is not a streaming producer we need to start
+ # repeatedly calling `resumeProducing` in a loop.
+ run_in_background(self._resumeProducingRepeatedly)
+
+ def content_length(self) -> Optional[int]:
+ """
+        Calculate the content length of the multipart response in bytes,
+        or None if the length of the wrapped content is unknown.
+        """
+        if self.length is None:
+            return None
+ # calculate length of json field and content-type header
+ json_field = json.dumps(self.json_field)
+ json_bytes = json_field.encode("utf-8")
+ json_length = len(json_bytes)
+
+        content_type_bytes = self.file_content_type.encode("utf-8")
+        content_type = Header(b"Content-Type", content_type_bytes)
+        type_length = len(bytes(content_type))
+
+        # 154 bytes is the fixed overhead of the multipart framing: the
+        # opening and middle boundary lines (CRLF + "--" + 32-char hex
+        # boundary + CRLF, 38 bytes each), the json Content-Type header
+        # plus its CRLF (32 bytes), the CRLF before the json body (2),
+        # the two CRLFs after the file Content-Type header (4), and the
+        # closing boundary line (40).
+ self.length += json_length + type_length + 154
+
+ return self.length
+
+ ### Internal APIs. ###
+
+ async def _resumeProducingRepeatedly(self) -> None:
+ assert self.producer is not None
+ assert not self.streaming
+
+ producer = cast("interfaces.IPullProducer", self.producer)
+
+ self.paused = False
+ while not self.paused:
+ producer.resumeProducing()
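+            # Yield to the reactor by sleeping for 0 seconds between
+            # iterations so that other work can run.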
+ await self.clock.sleep(0)
+
+
+class Header:
+ """
+    A tiny wrapper that produces request headers. We can't use the
+    standard Python header class because it encodes unicode fields using
+    `=?...?=` (RFC 2047) encoding, which is correct, but not what anyone
+    in the HTTP world expects: everyone wants raw utf-8 bytes. (Stolen
+    from treq.multipart.)
+    """
+
+ def __init__(
+ self,
+ name: bytes,
+ value: Any,
+ params: Optional[List[Tuple[Any, Any]]] = None,
+ ):
+ self.name = name
+ self.value = value
+ self.params = params or []
+
+ def add_param(self, name: Any, value: Any) -> None:
+ self.params.append((name, value))
+
+ def __bytes__(self) -> bytes:
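+        # e.g. Header(b"Content-Type", b"text/plain", [(b"charset", b"utf-8")])
+        # serializes to b'Content-Type: text/plain; charset="utf-8"'.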
+ with closing(BytesIO()) as h:
+ h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
+ if self.params:
+ for name, val in self.params:
+ h.write(b"; ")
+ h.write(escape(name).encode("us-ascii"))
+ h.write(b"=")
+ h.write(b'"' + escape(val).encode("utf-8") + b'"')
+ h.seek(0)
+ return h.read()
+
+
+def escape(value: Union[str, bytes]) -> str:
+ """
+    Sanitise a header value so it cannot corrupt the request: a newline
+    in the file name parameter, for example, makes a form-data request
+    unreadable to the majority of parsers. (Stolen from treq.multipart.)
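+
+    e.g. escape('foo\r\nbar"baz') returns 'foobar\\"baz'.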
+ """
+ if isinstance(value, bytes):
+ value = value.decode("utf-8")
+ return value.replace("\r", "").replace("\n", "").replace('"', '\\"')