summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-07-02 06:07:04 -0700
committerGitHub <noreply@github.com>2024-07-02 14:07:04 +0100
commit8f890447b0f8b6cbe369b162670185e8c746b2f2 (patch)
treec8c290661a59b06257ce7e2fda19e799d83825eb /synapse/http
parentFix sync waiting for an invalid token from the "future" (#17386) (diff)
downloadsynapse-8f890447b0f8b6cbe369b162670185e8c746b2f2.tar.xz
Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365)
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py152
-rw-r--r--synapse/http/matrixfederationclient.py192
2 files changed, 344 insertions, 0 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4718517c97..56ad28eabf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -35,6 +35,8 @@ from typing import (
     Union,
 )
 
+import attr
+import multipart
 import treq
 from canonicaljson import encode_canonical_json
 from netaddr import AddrFormatError, IPAddress, IPSet
@@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
         self._maybe_fail()
 
 
+@attr.s(auto_attribs=True, slots=True)
+class MultipartResponse:
+    """
+    A small class to hold parsed values of a multipart response.
+    """
+
+    json: bytes = b"{}"
+    length: Optional[int] = None
+    content_type: Optional[bytes] = None
+    disposition: Optional[bytes] = None
+    url: Optional[bytes] = None
+
+
+class _MultipartParserProtocol(protocol.Protocol):
+    """
+    Protocol to read and parse a MSC3916 multipart/mixed response
+    """
+
+    transport: Optional[ITCPTransport] = None
+
+    def __init__(
+        self,
+        stream: ByteWriteable,
+        deferred: defer.Deferred,
+        boundary: str,
+        max_length: Optional[int],
+    ) -> None:
+        self.stream = stream
+        self.deferred = deferred
+        self.boundary = boundary
+        self.max_length = max_length
+        self.parser = None
+        self.multipart_response = MultipartResponse()
+        self.has_redirect = False
+        self.in_json = False
+        self.json_done = False
+        self.file_length = 0
+        self.total_length = 0
+        self.in_disposition = False
+        self.in_content_type = False
+
+    def dataReceived(self, incoming_data: bytes) -> None:
+        if self.deferred.called:
+            return
+
+        # we don't have a parser yet, instantiate it
+        if not self.parser:
+
+            def on_header_field(data: bytes, start: int, end: int) -> None:
+                if data[start:end] == b"Location":
+                    self.has_redirect = True
+                if data[start:end] == b"Content-Disposition":
+                    self.in_disposition = True
+                if data[start:end] == b"Content-Type":
+                    self.in_content_type = True
+
+            def on_header_value(data: bytes, start: int, end: int) -> None:
+                # the first header should be content-type for application/json
+                if not self.in_json and not self.json_done:
+                    assert data[start:end] == b"application/json"
+                    self.in_json = True
+                elif self.has_redirect:
+                    self.multipart_response.url = data[start:end]
+                elif self.in_content_type:
+                    self.multipart_response.content_type = data[start:end]
+                    self.in_content_type = False
+                elif self.in_disposition:
+                    self.multipart_response.disposition = data[start:end]
+                    self.in_disposition = False
+
+            def on_part_data(data: bytes, start: int, end: int) -> None:
+                # we've seen json header but haven't written the json data
+                if self.in_json and not self.json_done:
+                    self.multipart_response.json = data[start:end]
+                    self.json_done = True
+                # we have a redirect header rather than a file, and have already captured it
+                elif self.has_redirect:
+                    return
+                # otherwise we are in the file part
+                else:
+                    logger.info("Writing multipart file data to stream")
+                    try:
+                        self.stream.write(data[start:end])
+                    except Exception as e:
+                        logger.warning(
+                            f"Exception encountered writing file data to stream: {e}"
+                        )
+                        self.deferred.errback()
+                    self.file_length += end - start
+
+            callbacks = {
+                "on_header_field": on_header_field,
+                "on_header_value": on_header_value,
+                "on_part_data": on_part_data,
+            }
+            self.parser = multipart.MultipartParser(self.boundary, callbacks)
+
+        self.total_length += len(incoming_data)
+        if self.max_length is not None and self.total_length >= self.max_length:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            assert self.transport is not None
+            self.transport.abortConnection()
+
+        try:
+            self.parser.write(incoming_data)  # type: ignore[attr-defined]
+        except Exception as e:
+            logger.warning(f"Exception writing to multipart parser: {e}")
+            self.deferred.errback()
+            return
+
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
+        if reason.check(ResponseDone):
+            self.multipart_response.length = self.file_length
+            self.deferred.callback(self.multipart_response)
+        else:
+            self.deferred.errback(reason)
+
+
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
@@ -1091,6 +1217,32 @@ def read_body_with_max_size(
     return d
 
 
+def read_multipart_response(
+    response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
+) -> "defer.Deferred[MultipartResponse]":
+    """
+    Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
+    the stream passed in and returning a deferred resolving to a MultipartResponse
+
+    Args:
+        response: The HTTP response to read from.
+        stream: The file-object to write to.
+        boundary: the multipart/mixed boundary string
+        max_length: maximum allowable length of the response
+    """
+    d: defer.Deferred[MultipartResponse] = defer.Deferred()
+
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_length is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_length:
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d
+
+    response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
+    return d
+
+
 def encode_query_args(args: Optional[QueryParams]) -> bytes:
     """
     Encodes a map of query arguments to bytes which can be appended to a URL.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 104b803b0f..749b01dd0e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -75,9 +75,11 @@ from synapse.http.client import (
     BlocklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    SimpleHttpClient,
     _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
+    read_multipart_response,
 )
 from synapse.http.connectproxyclient import BearerProxyCredentials
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
@@ -466,6 +468,13 @@ class MatrixFederationHttpClient:
 
         self._sleeper = AwakenableSleeper(self.reactor)
 
+        self._simple_http_client = SimpleHttpClient(
+            hs,
+            ip_blocklist=hs.config.server.federation_ip_range_blocklist,
+            ip_allowlist=hs.config.server.federation_ip_range_allowlist,
+            use_proxy=True,
+        )
+
     def wake_destination(self, destination: str) -> None:
         """Called when the remote server may have come back online."""
 
@@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient:
         )
         return length, headers
 
+    async def federation_get_file(
+        self,
+        destination: str,
+        path: str,
+        output_stream: BinaryIO,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+        max_size: int,
+        args: Optional[QueryParams] = None,
+        retry_on_dns_fail: bool = True,
+        ignore_backoff: bool = False,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+        """GETs a file from a given homeserver over the federation /download endpoint
+        Args:
+            destination: The remote server to send the HTTP request to.
+            path: The HTTP path to GET.
+            output_stream: File to write the response body to.
+            download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+                requester IP
+            ip_address: IP address of the requester
+            max_size: maximum allowable size in bytes of the file
+            args: Optional dictionary used to create the query string.
+            ignore_backoff: true to ignore the historical backoff data
+                and try the request anyway.
+
+        Returns:
+            Resolves to an (int, dict, bytes) tuple of
+            the file length, a dict of the response headers, and the file json
+
+        Raises:
+            HttpResponseException: If we get an HTTP response code >= 300
+                (except 429).
+            NotRetryingDestination: If we are not yet ready to retry this
+                server.
+            FederationDeniedError: If this destination is not on our
+                federation whitelist
+            RequestSendFailed: If there were problems connecting to the
+                remote, due to e.g. DNS failures, connection timeouts etc.
+            SynapseError: If the requested file exceeds ratelimits or the response from the
+            remote server is not a multipart response
+            AssertionError: if the resolved multipart response's length is None
+        """
+        request = MatrixFederationRequest(
+            method="GET", destination=destination, path=path, query=args
+        )
+
+        # check for a minimum balance of 1MiB in ratelimiter before initiating request
+        send_req, _ = await download_ratelimiter.can_do_action(
+            requester=None, key=ip_address, n_actions=1048576, update=False
+        )
+
+        if not send_req:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+        response = await self._send_request(
+            request,
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
+        )
+
+        headers = dict(response.headers.getAllRawHeaders())
+
+        expected_size = response.length
+        # if we don't get an expected length then use the max length
+        if expected_size == UNKNOWN_LENGTH:
+            expected_size = max_size
+            logger.debug(
+                f"File size unknown, assuming file is max allowable size: {max_size}"
+            )
+
+        read_body, _ = await download_ratelimiter.can_do_action(
+            requester=None,
+            key=ip_address,
+            n_actions=expected_size,
+        )
+        if not read_body:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+        # this should be a multipart/mixed response with the boundary string in the header
+        try:
+            raw_content_type = headers.get(b"Content-Type")
+            assert raw_content_type is not None
+            content_type = raw_content_type[0].decode("UTF-8")
+            content_type_parts = content_type.split("boundary=")
+            boundary = content_type_parts[1]
+        except Exception:
+            msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present."
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
+
+        try:
+            # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
+            deferred = read_multipart_response(
+                response, output_stream, boundary, expected_size + 1
+            )
+            deferred.addTimeout(self.default_timeout_seconds, self.reactor)
+        except BodyExceededMaxSize:
+            msg = "Requested file is too large > %r bytes" % (expected_size,)
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
+        except defer.TimeoutError as e:
+            logger.warning(
+                "{%s} [%s] Timed out reading response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except ResponseFailed as e:
+            logger.warning(
+                "{%s} [%s] Failed to read response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except Exception as e:
+            logger.warning(
+                "{%s} [%s] Error reading response: %s",
+                request.txn_id,
+                request.destination,
+                e,
+            )
+            raise
+
+        multipart_response = await make_deferred_yieldable(deferred)
+        if not multipart_response.url:
+            assert multipart_response.length is not None
+            length = multipart_response.length
+            headers[b"Content-Type"] = [multipart_response.content_type]
+            headers[b"Content-Disposition"] = [multipart_response.disposition]
+
+        # the response contained a redirect url to download the file from
+        else:
+            str_url = multipart_response.url.decode("utf-8")
+            logger.info(
+                "{%s} [%s] File download redirected, now downloading from: %s",
+                request.txn_id,
+                request.destination,
+                str_url,
+            )
+            length, headers, _, _ = await self._simple_http_client.get_file(
+                str_url, output_stream, expected_size
+            )
+
+        logger.info(
+            "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
+            request.txn_id,
+            request.destination,
+            response.code,
+            response.phrase.decode("ascii", errors="replace"),
+            length,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        return length, headers, multipart_response.json
+
 
 def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):