summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-06-25 07:35:37 -0700
committerGitHub <noreply@github.com>2024-06-25 14:35:37 +0000
commita023538822c8e241cdd3180c9cbbcb0f4eb84844 (patch)
treea683081ca5898833895ec4cc70a4f0959a48df06 /synapse
parentFix refreshable_access_token_lifetime typo (#17357) (diff)
downloadsynapse-a023538822c8e241cdd3180c9cbbcb0f4eb84844.tar.xz
Re-introduce federation /download endpoint (#17350)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/transport/server/__init__.py8
-rw-r--r--synapse/federation/transport/server/_base.py24
-rw-r--r--synapse/federation/transport/server/federation.py41
-rw-r--r--synapse/media/_base.py78
-rw-r--r--synapse/media/media_repository.py14
-rw-r--r--synapse/media/media_storage.py259
6 files changed, 413 insertions, 11 deletions
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index bac569e977..edaf0196d6 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -33,6 +33,7 @@ from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
     FederationAccountStatusServlet,
     FederationUnstableClientKeysClaimServlet,
+    FederationUnstableMediaDownloadServlet,
 )
 from synapse.http.server import HttpServer, JsonResource
 from synapse.http.servlet import (
@@ -315,6 +316,13 @@ def register_servlets(
             ):
                 continue
 
+            if servletclass == FederationUnstableMediaDownloadServlet:
+                if (
+                    not hs.config.server.enable_media_repo
+                    or not hs.config.experimental.msc3916_authenticated_media_enabled
+                ):
+                    continue
+
             servletclass(
                 hs=hs,
                 authenticator=authenticator,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index db0f5076a9..4e2717b565 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -360,13 +360,29 @@ class BaseFederationServlet:
                                     "request"
                                 )
                                 return None
+                            if (
+                                func.__self__.__class__.__name__  # type: ignore
+                                == "FederationUnstableMediaDownloadServlet"
+                            ):
+                                response = await func(
+                                    origin, content, request, *args, **kwargs
+                                )
+                            else:
+                                response = await func(
+                                    origin, content, request.args, *args, **kwargs
+                                )
+                    else:
+                        if (
+                            func.__self__.__class__.__name__  # type: ignore
+                            == "FederationUnstableMediaDownloadServlet"
+                        ):
+                            response = await func(
+                                origin, content, request, *args, **kwargs
+                            )
+                        else:
                             response = await func(
                                 origin, content, request.args, *args, **kwargs
                             )
-                    else:
-                        response = await func(
-                            origin, content, request.args, *args, **kwargs
-                        )
             finally:
                 # if we used the origin's context as the parent, add a new span using
                 # the servlet span as a parent, so that we have a link
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index a59734785f..67bb907050 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -44,10 +44,13 @@ from synapse.federation.transport.server._base import (
 )
 from synapse.http.servlet import (
     parse_boolean_from_args,
+    parse_integer,
     parse_integer_from_args,
     parse_string_from_args,
     parse_strings_from_args,
 )
+from synapse.http.site import SynapseRequest
+from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
 from synapse.types import JsonDict
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -787,6 +790,43 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
         return 200, {"account_statuses": statuses, "failures": failures}
 
 
+class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
+    """
+    Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
+    a multipart/mixed response consisting of a JSON object and the requested media
+    item. This endpoint only returns local media.
+    """
+
+    PATH = "/media/download/(?P<media_id>[^/]*)"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
+    RATELIMIT = True
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        ratelimiter: FederationRateLimiter,
+        authenticator: Authenticator,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.media_repo = self.hs.get_media_repository()
+
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        request: SynapseRequest,
+        media_id: str,
+    ) -> None:
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+        )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+        await self.media_repo.get_local_media(
+            request, media_id, None, max_timeout_ms, federation=True
+        )
+
+
 FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationSendServlet,
     FederationEventServlet,
@@ -818,4 +858,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
     FederationAccountStatusServlet,
+    FederationUnstableMediaDownloadServlet,
 )
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 3fbed6062f..7ad0b7c3cf 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -25,7 +25,16 @@ import os
 import urllib
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+)
 
 import attr
 
@@ -37,8 +46,13 @@ from synapse.api.errors import Codes, cs_error
 from synapse.http.server import finish_request, respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
+from synapse.util import Clock
 from synapse.util.stringutils import is_ascii
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main.media_repository import LocalMedia
+
+
 logger = logging.getLogger(__name__)
 
 # list all text content types that will have the charset default to UTF-8 when
@@ -260,6 +274,68 @@ def _can_encode_filename_as_token(x: str) -> bool:
     return True
 
 
+async def respond_with_multipart_responder(
+    clock: Clock,
+    request: SynapseRequest,
+    responder: "Optional[Responder]",
+    media_info: "LocalMedia",
+) -> None:
+    """
+    Responds to requests originating from the federation media `/download` endpoint by
+    streaming a multipart/mixed response
+
+    Args:
+        clock:
+        request: the federation request to respond to
+        responder: the responder which will send the response
+        media_info: metadata about the media item
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    # If we have a responder we *must* use it as a context manager.
+    with responder:
+        if request._disconnected:
+            logger.warning(
+                "Not sending response to request %s, already disconnected.", request
+            )
+            return
+
+        from synapse.media.media_storage import MultipartFileConsumer
+
+        # note that currently the json_object is just {}, this will change when linked media
+        # is implemented
+        multipart_consumer = MultipartFileConsumer(
+            clock, request, media_info.media_type, {}, media_info.media_length
+        )
+
+        logger.debug("Responding to media request with responder %s", responder)
+        if media_info.media_length is not None:
+            content_length = multipart_consumer.content_length()
+            assert content_length is not None
+            request.setHeader(b"Content-Length", b"%d" % (content_length,))
+
+        request.setHeader(
+            b"Content-Type",
+            b"multipart/mixed; boundary=%s" % multipart_consumer.boundary,
+        )
+
+        try:
+            await responder.write_to_consumer(multipart_consumer)
+        except Exception as e:
+            # The majority of the time this will be due to the client having gone
+            # away. Unfortunately, Twisted simply throws a generic exception at us
+            # in that case.
+            logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+            # Unregister the producer, if it has one, so Twisted doesn't complain
+            if request.producer:
+                request.unregisterProducer()
+
+    finish_request(request)
+
+
 async def respond_with_responder(
     request: SynapseRequest,
     responder: "Optional[Responder]",
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 6ed56099ca..1436329fad 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -54,6 +54,7 @@ from synapse.media._base import (
     ThumbnailInfo,
     get_filename_from_headers,
     respond_404,
+    respond_with_multipart_responder,
     respond_with_responder,
 )
 from synapse.media.filepath import MediaFilePaths
@@ -429,6 +430,7 @@ class MediaRepository:
         media_id: str,
         name: Optional[str],
         max_timeout_ms: int,
+        federation: bool = False,
     ) -> None:
         """Responds to requests for local media, if exists, or returns 404.
 
@@ -440,6 +442,7 @@ class MediaRepository:
                 the filename in the Content-Disposition header of the response.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            federation: whether the local media being fetched is for a federation request
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -460,9 +463,14 @@ class MediaRepository:
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
         responder = await self.media_storage.fetch_media(file_info)
-        await respond_with_responder(
-            request, responder, media_type, media_length, upload_name
-        )
+        if federation:
+            await respond_with_multipart_responder(
+                self.clock, request, responder, media_info
+            )
+        else:
+            await respond_with_responder(
+                request, responder, media_type, media_length, upload_name
+            )
 
     async def get_remote_media(
         self,
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b3cd3fd8f4..1be2c9b5f5 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,9 +19,12 @@
 #
 #
 import contextlib
+import json
 import logging
 import os
 import shutil
+from contextlib import closing
+from io import BytesIO
 from types import TracebackType
 from typing import (
     IO,
@@ -30,24 +33,35 @@ from typing import (
     AsyncIterator,
     BinaryIO,
     Callable,
+    List,
     Optional,
     Sequence,
     Tuple,
     Type,
+    Union,
+    cast,
 )
+from uuid import uuid4
 
 import attr
+from zope.interface import implementer
 
+from twisted.internet import interfaces
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.api.errors import NotFoundError
-from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.logging.context import (
+    defer_to_thread,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
 from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
+from ..types import JsonDict
 from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
@@ -57,6 +71,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+CRLF = b"\r\n"
+
 
 class MediaStorage:
     """Responsible for storing/fetching files from local sources.
@@ -174,7 +190,7 @@ class MediaStorage:
         and configured storage providers.
 
         Args:
-            file_info
+            file_info: Metadata about the media file
 
         Returns:
             Returns a Responder if the file was found, otherwise None.
@@ -316,7 +332,7 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file: A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed to the client,
             is closed when finished streaming.
     """
 
@@ -370,3 +386,240 @@ class ReadableFileWrapper:
 
                 # We yield to the reactor by sleeping for 0 seconds.
                 await self.clock.sleep(0)
+
+
+@implementer(interfaces.IConsumer)
+@implementer(interfaces.IPushProducer)
+class MultipartFileConsumer:
+    """Wraps a given consumer so that any data that gets written to it gets
+    converted to a multipart format.
+    """
+
+    def __init__(
+        self,
+        clock: Clock,
+        wrapped_consumer: interfaces.IConsumer,
+        file_content_type: str,
+        json_object: JsonDict,
+        content_length: Optional[int] = None,
+    ) -> None:
+        self.clock = clock
+        self.wrapped_consumer = wrapped_consumer
+        self.json_field = json_object
+        self.json_field_written = False
+        self.content_type_written = False
+        self.file_content_type = file_content_type
+        self.boundary = uuid4().hex.encode("ascii")
+
+        # The producer that registered with us, and if it's a push or pull
+        # producer.
+        self.producer: Optional["interfaces.IProducer"] = None
+        self.streaming: Optional[bool] = None
+
+        # Whether the wrapped consumer has asked us to pause.
+        self.paused = False
+
+        self.length = content_length
+
+    ### IConsumer APIs ###
+
+    def registerProducer(
+        self, producer: "interfaces.IProducer", streaming: bool
+    ) -> None:
+        """
+        Register to receive data from a producer.
+
+        This sets self to be a consumer for a producer.  When this object runs
+        out of data (as when a send(2) call on a socket succeeds in moving the
+        last data from a userspace buffer into a kernelspace buffer), it will
+        ask the producer to resumeProducing().
+
+        For L{IPullProducer} providers, C{resumeProducing} will be called once
+        each time data is required.
+
+        For L{IPushProducer} providers, C{pauseProducing} will be called
+        whenever the write buffer fills up and C{resumeProducing} will only be
+        called when it empties.  The consumer will only call C{resumeProducing}
+        to balance a previous C{pauseProducing} call; the producer is assumed
+        to start in an un-paused state.
+
+        @param streaming: C{True} if C{producer} provides L{IPushProducer},
+            C{False} if C{producer} provides L{IPullProducer}.
+
+        @raise RuntimeError: If a producer is already registered.
+        """
+        self.producer = producer
+        self.streaming = streaming
+
+        self.wrapped_consumer.registerProducer(self, True)
+
+        # kick off producing if `self.producer` is not a streaming producer
+        if not streaming:
+            self.resumeProducing()
+
+    def unregisterProducer(self) -> None:
+        """
+        Stop consuming data from a producer, without disconnecting.
+        """
+        self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
+        self.wrapped_consumer.unregisterProducer()
+        self.paused = True
+
+    def write(self, data: bytes) -> None:
+        """
+        The producer will write data by calling this method.
+
+        The implementation must be non-blocking and perform whatever
+        buffering is necessary.  If the producer has provided enough data
+        for now and it is a L{IPushProducer}, the consumer may call its
+        C{pauseProducing} method.
+        """
+        if not self.json_field_written:
+            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+            content_type = Header(b"Content-Type", b"application/json")
+            self.wrapped_consumer.write(bytes(content_type) + CRLF)
+
+            json_field = json.dumps(self.json_field)
+            json_bytes = json_field.encode("utf-8")
+            self.wrapped_consumer.write(CRLF + json_bytes)
+            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+            self.json_field_written = True
+
+        # if we haven't written the content type yet, do so
+        if not self.content_type_written:
+            type = self.file_content_type.encode("utf-8")
+            content_type = Header(b"Content-Type", type)
+            self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
+            self.content_type_written = True
+
+        self.wrapped_consumer.write(data)
+
+    ### IPushProducer APIs ###
+
+    def stopProducing(self) -> None:
+        """
+        Stop producing data.
+
+        This tells a producer that its consumer has died, so it must stop
+        producing data for good.
+        """
+        assert self.producer is not None
+
+        self.paused = True
+        self.producer.stopProducing()
+
+    def pauseProducing(self) -> None:
+        """
+        Pause producing data.
+
+        Tells a producer that it has produced too much data to process for
+        the time being, and to stop until C{resumeProducing()} is called.
+        """
+        assert self.producer is not None
+
+        self.paused = True
+
+        if self.streaming:
+            cast("interfaces.IPushProducer", self.producer).pauseProducing()
+        else:
+            self.paused = True
+
+    def resumeProducing(self) -> None:
+        """
+        Resume producing data.
+
+        This tells a producer to re-add itself to the main loop and produce
+        more data for its consumer.
+        """
+        assert self.producer is not None
+
+        if self.streaming:
+            cast("interfaces.IPushProducer", self.producer).resumeProducing()
+        else:
+            # If the producer is not a streaming producer we need to start
+            # repeatedly calling  `resumeProducing` in a loop.
+            run_in_background(self._resumeProducingRepeatedly)
+
+    def content_length(self) -> Optional[int]:
+        """
+        Calculate the content length of the multipart response
+        in bytes.
+        """
+        if not self.length:
+            return None
+        # calculate length of json field and content-type header
+        json_field = json.dumps(self.json_field)
+        json_bytes = json_field.encode("utf-8")
+        json_length = len(json_bytes)
+
+        type = self.file_content_type.encode("utf-8")
+        content_type = Header(b"Content-Type", type)
+        type_length = len(bytes(content_type))
+
+        # 154 is the length of the elements that aren't variable, ie
+        # CRLFs and boundary strings, etc
+        self.length += json_length + type_length + 154
+
+        return self.length
+
+    ### Internal APIs. ###
+
+    async def _resumeProducingRepeatedly(self) -> None:
+        assert self.producer is not None
+        assert not self.streaming
+
+        producer = cast("interfaces.IPullProducer", self.producer)
+
+        self.paused = False
+        while not self.paused:
+            producer.resumeProducing()
+            await self.clock.sleep(0)
+
+
+class Header:
+    """
+    `Header` This class is a tiny wrapper that produces
+    request headers. We can't use standard python header
+    class because it encodes unicode fields using =? bla bla ?=
+    encoding, which is correct, but no one in HTTP world expects
+    that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)
+
+    """
+
+    def __init__(
+        self,
+        name: bytes,
+        value: Any,
+        params: Optional[List[Tuple[Any, Any]]] = None,
+    ):
+        self.name = name
+        self.value = value
+        self.params = params or []
+
+    def add_param(self, name: Any, value: Any) -> None:
+        self.params.append((name, value))
+
+    def __bytes__(self) -> bytes:
+        with closing(BytesIO()) as h:
+            h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
+            if self.params:
+                for name, val in self.params:
+                    h.write(b"; ")
+                    h.write(escape(name).encode("us-ascii"))
+                    h.write(b"=")
+                    h.write(b'"' + escape(val).encode("utf-8") + b'"')
+            h.seek(0)
+            return h.read()
+
+
+def escape(value: Union[str, bytes]) -> str:
+    """
+    This function prevents header values from corrupting the request,
+    a newline in the file name parameter makes form-data request unreadable
+    for a majority of parsers. (stolen from treq.multipart)
+    """
+    if isinstance(value, bytes):
+        value = value.decode("utf-8")
+    return value.replace("\r", "").replace("\n", "").replace('"', '\\"')