summary refs log tree commit diff
path: root/synapse/media/media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/media/media_storage.py')
-rw-r--r--synapse/media/media_storage.py223
1 files changed, 8 insertions, 215 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 2f55d12b6b..b3cd3fd8f4 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,12 +19,9 @@
 #
 #
 import contextlib
-import json
 import logging
 import os
 import shutil
-from contextlib import closing
-from io import BytesIO
 from types import TracebackType
 from typing import (
     IO,
@@ -33,19 +30,14 @@ from typing import (
     AsyncIterator,
     BinaryIO,
     Callable,
-    List,
     Optional,
     Sequence,
     Tuple,
     Type,
-    Union,
 )
-from uuid import uuid4
 
 import attr
-from zope.interface import implementer
 
-from twisted.internet import defer, interfaces
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
@@ -56,19 +48,15 @@ from synapse.logging.opentracing import start_active_span, trace, trace_with_opn
 from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
-from ..storage.databases.main.media_repository import LocalMedia
-from ..types import JsonDict
 from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
 if TYPE_CHECKING:
-    from synapse.media.storage_provider import StorageProviderWrapper
+    from synapse.media.storage_provider import StorageProvider
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
-CRLF = b"\r\n"
-
 
 class MediaStorage:
     """Responsible for storing/fetching files from local sources.
@@ -85,7 +73,7 @@ class MediaStorage:
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProviderWrapper"],
+        storage_providers: Sequence["StorageProvider"],
     ):
         self.hs = hs
         self.reactor = hs.get_reactor()
@@ -181,23 +169,15 @@ class MediaStorage:
 
             raise e from None
 
-    async def fetch_media(
-        self,
-        file_info: FileInfo,
-        media_info: Optional[LocalMedia] = None,
-        federation: bool = False,
-    ) -> Optional[Responder]:
+    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
         and configured storage providers.
 
         Args:
-            file_info: Metadata about the media file
-            media_info: Metadata about the media item
-            federation: Whether this file is being fetched for a federation request
+            file_info
 
         Returns:
-            If the file was found returns a Responder (a Multipart Responder if the requested
-            file is for the federation /download endpoint), otherwise None.
+            Returns a Responder if the file was found, otherwise None.
         """
         paths = [self._file_info_to_path(file_info)]
 
@@ -217,19 +197,12 @@ class MediaStorage:
             local_path = os.path.join(self.local_media_directory, path)
             if os.path.exists(local_path):
                 logger.debug("responding with local file %s", local_path)
-                if federation:
-                    assert media_info is not None
-                    boundary = uuid4().hex.encode("ascii")
-                    return MultipartResponder(
-                        open(local_path, "rb"), media_info, boundary
-                    )
-                else:
-                    return FileResponder(open(local_path, "rb"))
+                return FileResponder(open(local_path, "rb"))
             logger.debug("local file %s did not exist", local_path)
 
         for provider in self.storage_providers:
             for path in paths:
-                res: Any = await provider.fetch(path, file_info, media_info, federation)
+                res: Any = await provider.fetch(path, file_info)
                 if res:
                     logger.debug("Streaming %s from %s", path, provider)
                     return res
@@ -343,7 +316,7 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file: A file like object to be streamed to the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """
 
@@ -364,38 +337,6 @@ class FileResponder(Responder):
         self.open_file.close()
 
 
-class MultipartResponder(Responder):
-    """Wraps an open file, formats the response according to MSC3916 and sends it to a
-    federation request.
-
-    Args:
-        open_file: A file like object to be streamed to the client,
-            is closed when finished streaming.
-        media_info: metadata about the media item
-        boundary: bytes to use for the multipart response boundary
-    """
-
-    def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None:
-        self.open_file = open_file
-        self.media_info = media_info
-        self.boundary = boundary
-
-    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
-        return make_deferred_yieldable(
-            MultipartFileSender().beginFileTransfer(
-                self.open_file, consumer, self.media_info.media_type, {}, self.boundary
-            )
-        )
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        self.open_file.close()
-
-
 class SpamMediaException(NotFoundError):
     """The media was blocked by a spam checker, so we simply 404 the request (in
     the same way as if it was quarantined).
@@ -429,151 +370,3 @@ class ReadableFileWrapper:
 
                 # We yield to the reactor by sleeping for 0 seconds.
                 await self.clock.sleep(0)
-
-
-@implementer(interfaces.IProducer)
-class MultipartFileSender:
-    """
-    A producer that sends the contents of a file to a federation request in the format
-    outlined in MSC3916 - a multipart/format-data response where the first field is a
-    JSON object and the second is the requested file.
-
-    This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format
-    outlined above.
-    """
-
-    CHUNK_SIZE = 2**14
-
-    lastSent = ""
-    deferred: Optional[defer.Deferred] = None
-
-    def beginFileTransfer(
-        self,
-        file: IO,
-        consumer: IConsumer,
-        file_content_type: str,
-        json_object: JsonDict,
-        boundary: bytes,
-    ) -> Deferred:
-        """
-        Begin transferring a file
-
-        Args:
-            file: The file object to read data from
-            consumer: The synapse request to write the data to
-            file_content_type: The content-type of the file
-            json_object: The JSON object to write to the first field of the response
-            boundary: bytes to be used as the multipart/form-data boundary
-
-        Returns:  A deferred whose callback will be invoked when the file has
-        been completely written to the consumer. The last byte written to the
-        consumer is passed to the callback.
-        """
-        self.file: Optional[IO] = file
-        self.consumer = consumer
-        self.json_field = json_object
-        self.json_field_written = False
-        self.content_type_written = False
-        self.file_content_type = file_content_type
-        self.boundary = boundary
-        self.deferred: Deferred = defer.Deferred()
-        self.consumer.registerProducer(self, False)
-        # while it's not entirely clear why this assignment is necessary, it mirrors
-        # the behavior in FileSender.beginFileTransfer and thus is preserved here
-        deferred = self.deferred
-        return deferred
-
-    def resumeProducing(self) -> None:
-        # write the first field, which will always be a json field
-        if not self.json_field_written:
-            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
-
-            content_type = Header(b"Content-Type", b"application/json")
-            self.consumer.write(bytes(content_type) + CRLF)
-
-            json_field = json.dumps(self.json_field)
-            json_bytes = json_field.encode("utf-8")
-            self.consumer.write(json_bytes)
-            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
-
-            self.json_field_written = True
-
-        chunk: Any = ""
-        if self.file:
-            # if we haven't written the content type yet, do so
-            if not self.content_type_written:
-                type = self.file_content_type.encode("utf-8")
-                content_type = Header(b"Content-Type", type)
-                self.consumer.write(bytes(content_type) + CRLF)
-                self.content_type_written = True
-
-            chunk = self.file.read(self.CHUNK_SIZE)
-
-        if not chunk:
-            # we've reached the end of the file
-            self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
-            self.file = None
-            self.consumer.unregisterProducer()
-
-            if self.deferred:
-                self.deferred.callback(self.lastSent)
-                self.deferred = None
-            return
-
-        self.consumer.write(chunk)
-        self.lastSent = chunk[-1:]
-
-    def pauseProducing(self) -> None:
-        pass
-
-    def stopProducing(self) -> None:
-        if self.deferred:
-            self.deferred.errback(Exception("Consumer asked us to stop producing"))
-            self.deferred = None
-
-
-class Header:
-    """
-    `Header` This class is a tiny wrapper that produces
-    request headers. We can't use standard python header
-    class because it encodes unicode fields using =? bla bla ?=
-    encoding, which is correct, but no one in HTTP world expects
-    that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)
-
-    """
-
-    def __init__(
-        self,
-        name: bytes,
-        value: Any,
-        params: Optional[List[Tuple[Any, Any]]] = None,
-    ):
-        self.name = name
-        self.value = value
-        self.params = params or []
-
-    def add_param(self, name: Any, value: Any) -> None:
-        self.params.append((name, value))
-
-    def __bytes__(self) -> bytes:
-        with closing(BytesIO()) as h:
-            h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
-            if self.params:
-                for name, val in self.params:
-                    h.write(b"; ")
-                    h.write(escape(name).encode("us-ascii"))
-                    h.write(b"=")
-                    h.write(b'"' + escape(val).encode("utf-8") + b'"')
-            h.seek(0)
-            return h.read()
-
-
-def escape(value: Union[str, bytes]) -> str:
-    """
-    This function prevents header values from corrupting the request,
-    a newline in the file name parameter makes form-data request unreadable
-    for a majority of parsers. (stolen from treq.multipart)
-    """
-    if isinstance(value, bytes):
-        value = value.decode("utf-8")
-    return value.replace("\r", "").replace("\n", "").replace('"', '\\"')