summary refs log tree commit diff
path: root/synapse/media/media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/media/media_storage.py')
-rw-r--r--synapse/media/media_storage.py223
1 files changed, 215 insertions, 8 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b3cd3fd8f4..2f55d12b6b 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,9 +19,12 @@
 #
 #
 import contextlib
+import json
 import logging
 import os
 import shutil
+from contextlib import closing
+from io import BytesIO
 from types import TracebackType
 from typing import (
     IO,
@@ -30,14 +33,19 @@ from typing import (
     AsyncIterator,
     BinaryIO,
     Callable,
+    List,
     Optional,
     Sequence,
     Tuple,
     Type,
+    Union,
 )
+from uuid import uuid4
 
 import attr
+from zope.interface import implementer
 
+from twisted.internet import defer, interfaces
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
@@ -48,15 +56,19 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn
 from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
+from ..storage.databases.main.media_repository import LocalMedia
+from ..types import JsonDict
 from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
 if TYPE_CHECKING:
-    from synapse.media.storage_provider import StorageProvider
+    from synapse.media.storage_provider import StorageProviderWrapper
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
+CRLF = b"\r\n"
+
 
 class MediaStorage:
     """Responsible for storing/fetching files from local sources.
@@ -73,7 +85,7 @@ class MediaStorage:
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProvider"],
+        storage_providers: Sequence["StorageProviderWrapper"],
     ):
         self.hs = hs
         self.reactor = hs.get_reactor()
@@ -169,15 +181,23 @@ class MediaStorage:
 
             raise e from None
 
-    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
+    async def fetch_media(
+        self,
+        file_info: FileInfo,
+        media_info: Optional[LocalMedia] = None,
+        federation: bool = False,
+    ) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
         and configured storage providers.
 
         Args:
-            file_info
+            file_info: Metadata about the media file
+            media_info: Metadata about the media item
+            federation: Whether this file is being fetched for a federation request
 
         Returns:
-            Returns a Responder if the file was found, otherwise None.
+            If the file was found returns a Responder (a Multipart Responder if the requested
+            file is for the federation /download endpoint), otherwise None.
         """
         paths = [self._file_info_to_path(file_info)]
 
@@ -197,12 +217,19 @@ class MediaStorage:
             local_path = os.path.join(self.local_media_directory, path)
             if os.path.exists(local_path):
                 logger.debug("responding with local file %s", local_path)
-                return FileResponder(open(local_path, "rb"))
+                if federation:
+                    assert media_info is not None
+                    boundary = uuid4().hex.encode("ascii")
+                    return MultipartResponder(
+                        open(local_path, "rb"), media_info, boundary
+                    )
+                else:
+                    return FileResponder(open(local_path, "rb"))
             logger.debug("local file %s did not exist", local_path)
 
         for provider in self.storage_providers:
             for path in paths:
-                res: Any = await provider.fetch(path, file_info)
+                res: Any = await provider.fetch(path, file_info, media_info, federation)
                 if res:
                     logger.debug("Streaming %s from %s", path, provider)
                     return res
@@ -316,7 +343,7 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file: A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed to the client,
             is closed when finished streaming.
     """
 
@@ -337,6 +364,38 @@ class FileResponder(Responder):
         self.open_file.close()
 
 
+class MultipartResponder(Responder):
+    """Wraps an open file, formats the response according to MSC3916 and sends it to a
+    federation request.
+
+    Args:
+        open_file: A file like object to be streamed to the client,
+            is closed when finished streaming.
+        media_info: metadata about the media item
+        boundary: bytes to use for the multipart response boundary
+    """
+
+    def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None:
+        self.open_file = open_file
+        self.media_info = media_info
+        self.boundary = boundary
+
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
+        return make_deferred_yieldable(
+            MultipartFileSender().beginFileTransfer(
+                self.open_file, consumer, self.media_info.media_type, {}, self.boundary
+            )
+        )
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.open_file.close()
+
+
 class SpamMediaException(NotFoundError):
     """The media was blocked by a spam checker, so we simply 404 the request (in
     the same way as if it was quarantined).
@@ -370,3 +429,151 @@ class ReadableFileWrapper:
 
                 # We yield to the reactor by sleeping for 0 seconds.
                 await self.clock.sleep(0)
+
+
+@implementer(interfaces.IProducer)
+class MultipartFileSender:
+    """
+    A producer that sends the contents of a file to a federation request in the format
+    outlined in MSC3916 - a multipart/format-data response where the first field is a
+    JSON object and the second is the requested file.
+
+    This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format
+    outlined above.
+    """
+
+    CHUNK_SIZE = 2**14
+
+    lastSent = ""
+    deferred: Optional[defer.Deferred] = None
+
+    def beginFileTransfer(
+        self,
+        file: IO,
+        consumer: IConsumer,
+        file_content_type: str,
+        json_object: JsonDict,
+        boundary: bytes,
+    ) -> Deferred:
+        """
+        Begin transferring a file
+
+        Args:
+            file: The file object to read data from
+            consumer: The synapse request to write the data to
+            file_content_type: The content-type of the file
+            json_object: The JSON object to write to the first field of the response
+            boundary: bytes to be used as the multipart/form-data boundary
+
+        Returns:  A deferred whose callback will be invoked when the file has
+        been completely written to the consumer. The last byte written to the
+        consumer is passed to the callback.
+        """
+        self.file: Optional[IO] = file
+        self.consumer = consumer
+        self.json_field = json_object
+        self.json_field_written = False
+        self.content_type_written = False
+        self.file_content_type = file_content_type
+        self.boundary = boundary
+        self.deferred: Deferred = defer.Deferred()
+        self.consumer.registerProducer(self, False)
+        # while it's not entirely clear why this assignment is necessary, it mirrors
+        # the behavior in FileSender.beginFileTransfer and thus is preserved here
+        deferred = self.deferred
+        return deferred
+
+    def resumeProducing(self) -> None:
+        # write the first field, which will always be a json field
+        if not self.json_field_written:
+            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+            content_type = Header(b"Content-Type", b"application/json")
+            self.consumer.write(bytes(content_type) + CRLF)
+
+            json_field = json.dumps(self.json_field)
+            json_bytes = json_field.encode("utf-8")
+            self.consumer.write(json_bytes)
+            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+            self.json_field_written = True
+
+        chunk: Any = ""
+        if self.file:
+            # if we haven't written the content type yet, do so
+            if not self.content_type_written:
+                type = self.file_content_type.encode("utf-8")
+                content_type = Header(b"Content-Type", type)
+                self.consumer.write(bytes(content_type) + CRLF)
+                self.content_type_written = True
+
+            chunk = self.file.read(self.CHUNK_SIZE)
+
+        if not chunk:
+            # we've reached the end of the file
+            self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
+            self.file = None
+            self.consumer.unregisterProducer()
+
+            if self.deferred:
+                self.deferred.callback(self.lastSent)
+                self.deferred = None
+            return
+
+        self.consumer.write(chunk)
+        self.lastSent = chunk[-1:]
+
+    def pauseProducing(self) -> None:
+        pass
+
+    def stopProducing(self) -> None:
+        if self.deferred:
+            self.deferred.errback(Exception("Consumer asked us to stop producing"))
+            self.deferred = None
+
+
+class Header:
+    """
+    `Header` This class is a tiny wrapper that produces
+    request headers. We can't use standard python header
+    class because it encodes unicode fields using =? bla bla ?=
+    encoding, which is correct, but no one in HTTP world expects
+    that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)
+
+    """
+
+    def __init__(
+        self,
+        name: bytes,
+        value: Any,
+        params: Optional[List[Tuple[Any, Any]]] = None,
+    ):
+        self.name = name
+        self.value = value
+        self.params = params or []
+
+    def add_param(self, name: Any, value: Any) -> None:
+        self.params.append((name, value))
+
+    def __bytes__(self) -> bytes:
+        with closing(BytesIO()) as h:
+            h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
+            if self.params:
+                for name, val in self.params:
+                    h.write(b"; ")
+                    h.write(escape(name).encode("us-ascii"))
+                    h.write(b"=")
+                    h.write(b'"' + escape(val).encode("utf-8") + b'"')
+            h.seek(0)
+            return h.read()
+
+
+def escape(value: Union[str, bytes]) -> str:
+    """
+    This function prevents header values from corrupting the request,
+    a newline in the file name parameter makes form-data request unreadable
+    for a majority of parsers. (stolen from treq.multipart)
+    """
+    if isinstance(value, bytes):
+        value = value.decode("utf-8")
+    return value.replace("\r", "").replace("\n", "").replace('"', '\\"')