author     Erik Johnston <erik@matrix.org>  2024-06-17 17:58:43 +0100
committer  Erik Johnston <erik@matrix.org>  2024-06-17 17:58:43 +0100
commit     1a3a6b63ef823043ecd23d828b89ab9a79b7ac8b (patch)
tree       1f6d342326f176b3f825816f1e45cf5cc5412660
parent     Merge branch 'release-v1.109' into develop (diff)
download   synapse-1a3a6b63ef823043ecd23d828b89ab9a79b7ac8b.tar.xz
We need to support third-party storage providers,
so we need to wrap the responders rather than changing them.
-rw-r--r--  synapse/media/_base.py             |  16
-rw-r--r--  synapse/media/media_repository.py  |   7
-rw-r--r--  synapse/media/media_storage.py     | 254
-rw-r--r--  synapse/media/storage_provider.py  |   9
4 files changed, 151 insertions(+), 135 deletions(-)
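
In outline: instead of returning a special MultipartResponder, the repository
now wraps the request itself in a MultipartFileConsumer, which injects the
MSC3916 multipart framing. Any Responder, including one returned by a
third-party storage provider, then streams through it unchanged. A minimal
sketch of the new call shape, using the names from the _base.py hunk below:

    # Sketch only: mirrors respond_with_multipart_responder after this commit.
    multipart_consumer = MultipartFileConsumer(
        clock, request, media_info.media_type, {}
    )
    request.setHeader(
        b"Content-Type",
        b"multipart/mixed; boundary=%s" % multipart_consumer.boundary,
    )
    # The responder no longer needs to know anything about multipart:
    await responder.write_to_consumer(multipart_consumer)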
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 19bca94170..12fa1425b2 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -46,10 +46,10 @@ from synapse.api.errors import Codes, cs_error
 from synapse.http.server import finish_request, respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
+from synapse.util import Clock
 from synapse.util.stringutils import is_ascii
 
 if TYPE_CHECKING:
-    from synapse.media.media_storage import MultipartResponder
     from synapse.storage.databases.main.media_repository import LocalMedia
 
 
@@ -275,8 +275,9 @@ def _can_encode_filename_as_token(x: str) -> bool:
 
 
 async def respond_with_multipart_responder(
+    clock: Clock,
     request: SynapseRequest,
-    responder: "Optional[MultipartResponder]",
+    responder: "Optional[Responder]",
     media_info: "LocalMedia",
 ) -> None:
     """
@@ -299,15 +300,22 @@ async def respond_with_multipart_responder(
             )
             return
 
+        from synapse.media.media_storage import MultipartFileConsumer
+
+        multipart_consumer = MultipartFileConsumer(
+            clock, request, media_info.media_type, {}
+        )
+
         logger.debug("Responding to media request with responder %s", responder)
         if media_info.media_length is not None:
             request.setHeader(b"Content-Length", b"%d" % (media_info.media_length,))
         request.setHeader(
-            b"Content-Type", b"multipart/mixed; boundary=%s" % responder.boundary
+            b"Content-Type",
+            b"multipart/mixed; boundary=%s" % multipart_consumer.boundary,
         )
 
         try:
-            await responder.write_to_consumer(request)
+            await responder.write_to_consumer(multipart_consumer)
         except Exception as e:
             # The majority of the time this will be due to the client having gone
             # away. Unfortunately, Twisted simply throws a generic exception at us
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index c335e518a0..e9725c6b14 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -58,7 +58,7 @@ from synapse.media._base import (
     respond_with_responder,
 )
 from synapse.media.filepath import MediaFilePaths
-from synapse.media.media_storage import MediaStorage, MultipartResponder
+from synapse.media.media_storage import MediaStorage
 from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
@@ -467,8 +467,9 @@ class MediaRepository:
         )
         if federation:
             # this really should be a Multipart responder but just in case
-            assert isinstance(responder, MultipartResponder)
-            await respond_with_multipart_responder(request, responder, media_info)
+            await respond_with_multipart_responder(
+                self.clock, request, responder, media_info
+            )
         else:
             await respond_with_responder(
                 request, responder, media_type, media_length, upload_name
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 2f55d12b6b..baf947b873 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -39,19 +39,24 @@ from typing import (
     Tuple,
     Type,
     Union,
+    cast,
 )
 from uuid import uuid4
 
 import attr
 from zope.interface import implementer
 
-from twisted.internet import defer, interfaces
+from twisted.internet import interfaces
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.api.errors import NotFoundError
-from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.logging.context import (
+    defer_to_thread,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
 from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
@@ -217,14 +222,7 @@ class MediaStorage:
             local_path = os.path.join(self.local_media_directory, path)
             if os.path.exists(local_path):
                 logger.debug("responding with local file %s", local_path)
-                if federation:
-                    assert media_info is not None
-                    boundary = uuid4().hex.encode("ascii")
-                    return MultipartResponder(
-                        open(local_path, "rb"), media_info, boundary
-                    )
-                else:
-                    return FileResponder(open(local_path, "rb"))
+                return FileResponder(open(local_path, "rb"))
             logger.debug("local file %s did not exist", local_path)
 
         for provider in self.storage_providers:
@@ -364,38 +362,6 @@ class FileResponder(Responder):
         self.open_file.close()
 
 
-class MultipartResponder(Responder):
-    """Wraps an open file, formats the response according to MSC3916 and sends it to a
-    federation request.
-
-    Args:
-        open_file: A file like object to be streamed to the client,
-            is closed when finished streaming.
-        media_info: metadata about the media item
-        boundary: bytes to use for the multipart response boundary
-    """
-
-    def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None:
-        self.open_file = open_file
-        self.media_info = media_info
-        self.boundary = boundary
-
-    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
-        return make_deferred_yieldable(
-            MultipartFileSender().beginFileTransfer(
-                self.open_file, consumer, self.media_info.media_type, {}, self.boundary
-            )
-        )
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        self.open_file.close()
-
-
 class SpamMediaException(NotFoundError):
     """The media was blocked by a spam checker, so we simply 404 the request (in
     the same way as if it was quarantined).
@@ -431,105 +397,153 @@ class ReadableFileWrapper:
                 await self.clock.sleep(0)
 
 
-@implementer(interfaces.IProducer)
-class MultipartFileSender:
-    """
-    A producer that sends the contents of a file to a federation request in the format
-    outlined in MSC3916 - a multipart/format-data response where the first field is a
-    JSON object and the second is the requested file.
-
-    This is a slight re-writing of twisted.protocols.basic.FileSender to achieve the format
-    outlined above.
-    """
-
-    CHUNK_SIZE = 2**14
-
-    lastSent = ""
-    deferred: Optional[defer.Deferred] = None
-
-    def beginFileTransfer(
+@implementer(interfaces.IConsumer)
+@implementer(interfaces.IPushProducer)
+class MultipartFileConsumer:
+    def __init__(
         self,
-        file: IO,
-        consumer: IConsumer,
+        clock: Clock,
+        wrapped_consumer: interfaces.IConsumer,
         file_content_type: str,
         json_object: JsonDict,
-        boundary: bytes,
-    ) -> Deferred:
-        """
-        Begin transferring a file
-
-        Args:
-            file: The file object to read data from
-            consumer: The synapse request to write the data to
-            file_content_type: The content-type of the file
-            json_object: The JSON object to write to the first field of the response
-            boundary: bytes to be used as the multipart/form-data boundary
-
-        Returns:  A deferred whose callback will be invoked when the file has
-        been completely written to the consumer. The last byte written to the
-        consumer is passed to the callback.
-        """
-        self.file: Optional[IO] = file
-        self.consumer = consumer
+    ) -> None:
+        self.clock = clock
+        self.wrapped_consumer = wrapped_consumer
         self.json_field = json_object
         self.json_field_written = False
         self.content_type_written = False
         self.file_content_type = file_content_type
-        self.boundary = boundary
-        self.deferred: Deferred = defer.Deferred()
-        self.consumer.registerProducer(self, False)
-        # while it's not entirely clear why this assignment is necessary, it mirrors
-        # the behavior in FileSender.beginFileTransfer and thus is preserved here
-        deferred = self.deferred
-        return deferred
+        self.boundary = uuid4().hex.encode("ascii")
 
-    def resumeProducing(self) -> None:
-        # write the first field, which will always be a json field
+        self.producer: Optional["interfaces.IProducer"] = None
+        self.streaming: Optional[bool] = None
+
+        self.paused = False
+
+    def registerProducer(
+        self, producer: "interfaces.IProducer", streaming: bool
+    ) -> None:
+        """
+        Register to receive data from a producer.
+
+        This sets self to be a consumer for a producer.  When this object runs
+        out of data (as when a send(2) call on a socket succeeds in moving the
+        last data from a userspace buffer into a kernelspace buffer), it will
+        ask the producer to resumeProducing().
+
+        For L{IPullProducer} providers, C{resumeProducing} will be called once
+        each time data is required.
+
+        For L{IPushProducer} providers, C{pauseProducing} will be called
+        whenever the write buffer fills up and C{resumeProducing} will only be
+        called when it empties.  The consumer will only call C{resumeProducing}
+        to balance a previous C{pauseProducing} call; the producer is assumed
+        to start in an un-paused state.
+
+        @param streaming: C{True} if C{producer} provides L{IPushProducer},
+            C{False} if C{producer} provides L{IPullProducer}.
+
+        @raise RuntimeError: If a producer is already registered.
+        """
+        self.producer = producer
+        self.streaming = streaming
+
+        self.wrapped_consumer.registerProducer(self, True)
+
+        run_in_background(self._resumeProducingRepeatedly)
+
+    def unregisterProducer(self) -> None:
+        """
+        Stop consuming data from a producer, without disconnecting.
+        """
+        self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
+        self.wrapped_consumer.unregisterProducer()
+        self.paused = True
+
+    def write(self, data: bytes) -> None:
+        """
+        The producer will write data by calling this method.
+
+        The implementation must be non-blocking and perform whatever
+        buffering is necessary.  If the producer has provided enough data
+        for now and it is a L{IPushProducer}, the consumer may call its
+        C{pauseProducing} method.
+        """
         if not self.json_field_written:
-            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
+            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
 
             content_type = Header(b"Content-Type", b"application/json")
-            self.consumer.write(bytes(content_type) + CRLF)
+            self.wrapped_consumer.write(bytes(content_type) + CRLF)
 
             json_field = json.dumps(self.json_field)
             json_bytes = json_field.encode("utf-8")
-            self.consumer.write(json_bytes)
-            self.consumer.write(CRLF + b"--" + self.boundary + CRLF)
+            self.wrapped_consumer.write(json_bytes)
+            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
 
             self.json_field_written = True
 
-        chunk: Any = ""
-        if self.file:
-            # if we haven't written the content type yet, do so
-            if not self.content_type_written:
-                type = self.file_content_type.encode("utf-8")
-                content_type = Header(b"Content-Type", type)
-                self.consumer.write(bytes(content_type) + CRLF)
-                self.content_type_written = True
-
-            chunk = self.file.read(self.CHUNK_SIZE)
-
-        if not chunk:
-            # we've reached the end of the file
-            self.consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
-            self.file = None
-            self.consumer.unregisterProducer()
-
-            if self.deferred:
-                self.deferred.callback(self.lastSent)
-                self.deferred = None
-            return
+        # if we haven't written the content type yet, do so
+        if not self.content_type_written:
+            type = self.file_content_type.encode("utf-8")
+            content_type = Header(b"Content-Type", type)
+            self.wrapped_consumer.write(bytes(content_type) + CRLF)
+            self.content_type_written = True
+
+        self.wrapped_consumer.write(data)
+
+    def stopProducing(self) -> None:
+        """
+        Stop producing data.
 
-        self.consumer.write(chunk)
-        self.lastSent = chunk[-1:]
+        This tells a producer that its consumer has died, so it must stop
+        producing data for good.
+        """
+        assert self.producer is not None
+
+        self.paused = True
+        self.producer.stopProducing()
 
     def pauseProducing(self) -> None:
-        pass
+        """
+        Pause producing data.
 
-    def stopProducing(self) -> None:
-        if self.deferred:
-            self.deferred.errback(Exception("Consumer asked us to stop producing"))
-            self.deferred = None
+        Tells a producer that it has produced too much data to process for
+        the time being, and to stop until C{resumeProducing()} is called.
+        """
+        assert self.producer is not None
+
+        self.paused = True
+
+        if self.streaming:
+            cast("interfaces.IPushProducer", self.producer).pauseProducing()
+        else:
+            self.paused = True
+
+    def resumeProducing(self) -> None:
+        """
+        Resume producing data.
+
+        This tells a producer to re-add itself to the main loop and produce
+        more data for its consumer.
+        """
+        assert self.producer is not None
+
+        if self.streaming:
+            cast("interfaces.IPushProducer", self.producer).resumeProducing()
+            return
+
+        run_in_background(self._resumeProducingRepeatedly)
+
+    async def _resumeProducingRepeatedly(self) -> None:
+        assert self.producer is not None
+        assert not self.streaming
+
+        producer = cast("interfaces.IPullProducer", self.producer)
+
+        self.paused = False
+        while not self.paused:
+            producer.resumeProducing()
+            await self.clock.sleep(0)
 
 
 class Header:
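
The wire format is unchanged from the old MultipartFileSender. As a plain-Python
sketch, the bytes the consumer emits look like the following; this assumes
bytes(Header(b"Content-Type", value)) renders as b"Content-Type: <value>"
(the Header class is only partially visible in this diff):

    CRLF = b"\r\n"

    def multipart_frame(boundary: bytes, json_bytes: bytes,
                        file_content_type: bytes, file_bytes: bytes) -> bytes:
        # Mirrors write() above: JSON part first, then the file part.
        return (
            CRLF + b"--" + boundary + CRLF
            + b"Content-Type: application/json" + CRLF
            + json_bytes
            + CRLF + b"--" + boundary + CRLF
            + b"Content-Type: " + file_content_type + CRLF
            + file_bytes
            # Written by unregisterProducer() once the producer finishes:
            + CRLF + b"--" + boundary + b"--" + CRLF
        )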
diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py
index a2d50adf65..a71da3587c 100644
--- a/synapse/media/storage_provider.py
+++ b/synapse/media/storage_provider.py
@@ -24,7 +24,6 @@ import logging
 import os
 import shutil
 from typing import TYPE_CHECKING, Callable, Optional
-from uuid import uuid4
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
@@ -33,7 +32,7 @@ from synapse.util.async_helpers import maybe_awaitable
 
 from ..storage.databases.main.media_repository import LocalMedia
 from ._base import FileInfo, Responder
-from .media_storage import FileResponder, MultipartResponder
+from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
@@ -201,12 +200,6 @@ class FileStorageProviderBackend(StorageProvider):
 
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
-            if federation:
-                assert media_info is not None
-                boundary = uuid4().hex.encode("ascii")
-                return MultipartResponder(
-                    open(backup_fname, "rb"), media_info, boundary
-                )
             return FileResponder(open(backup_fname, "rb"))
 
         return None
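
The upshot for third-party providers: fetch() keeps returning a plain
Responder, and the federation-specific framing is applied later by the
wrapper. A hypothetical provider sketch (the _download helper is invented
for illustration and is not part of this commit):

    from typing import Optional

    from synapse.media._base import FileInfo, Responder
    from synapse.media.media_storage import FileResponder
    from synapse.media.storage_provider import StorageProvider

    class ExampleStorageProvider(StorageProvider):
        # store_file omitted for brevity.
        async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
            local_copy = await self._download(path)  # hypothetical download-to-disk helper
            if local_copy is None:
                return None
            # A plain FileResponder suffices: multipart framing for federation
            # requests is now added by MultipartFileConsumer at response time.
            return FileResponder(open(local_copy, "rb"))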