diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4718517c97..56ad28eabf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -35,6 +35,8 @@ from typing import (
Union,
)
+import attr
+import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
@@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self._maybe_fail()
+@attr.s(auto_attribs=True, slots=True)
+class MultipartResponse:
+ """
+ A small class to hold parsed values of a multipart response.
+ """
+
+ json: bytes = b"{}"
+ length: Optional[int] = None
+ content_type: Optional[bytes] = None
+ disposition: Optional[bytes] = None
+ url: Optional[bytes] = None
+
+
+class _MultipartParserProtocol(protocol.Protocol):
+ """
+ Protocol to read and parse a MSC3916 multipart/mixed response
+ """
+
+ transport: Optional[ITCPTransport] = None
+
+ def __init__(
+ self,
+ stream: ByteWriteable,
+ deferred: defer.Deferred,
+ boundary: str,
+ max_length: Optional[int],
+ ) -> None:
+ self.stream = stream
+ self.deferred = deferred
+ self.boundary = boundary
+ self.max_length = max_length
+ self.parser = None
+ self.multipart_response = MultipartResponse()
+ self.has_redirect = False
+ self.in_json = False
+ self.json_done = False
+ self.file_length = 0
+ self.total_length = 0
+ self.in_disposition = False
+ self.in_content_type = False
+
+ def dataReceived(self, incoming_data: bytes) -> None:
+ if self.deferred.called:
+ return
+
+ # we don't have a parser yet, instantiate it
+ if not self.parser:
+
+ def on_header_field(data: bytes, start: int, end: int) -> None:
+ if data[start:end] == b"Location":
+ self.has_redirect = True
+ if data[start:end] == b"Content-Disposition":
+ self.in_disposition = True
+ if data[start:end] == b"Content-Type":
+ self.in_content_type = True
+
+ def on_header_value(data: bytes, start: int, end: int) -> None:
+ # the first header should be content-type for application/json
+ if not self.in_json and not self.json_done:
+ assert data[start:end] == b"application/json"
+ self.in_json = True
+ elif self.has_redirect:
+ self.multipart_response.url = data[start:end]
+ elif self.in_content_type:
+ self.multipart_response.content_type = data[start:end]
+ self.in_content_type = False
+ elif self.in_disposition:
+ self.multipart_response.disposition = data[start:end]
+ self.in_disposition = False
+
+ def on_part_data(data: bytes, start: int, end: int) -> None:
+ # we've seen json header but haven't written the json data
+ if self.in_json and not self.json_done:
+ self.multipart_response.json = data[start:end]
+ self.json_done = True
+ # we have a redirect header rather than a file, and have already captured it
+ elif self.has_redirect:
+ return
+ # otherwise we are in the file part
+ else:
+ logger.info("Writing multipart file data to stream")
+ try:
+ self.stream.write(data[start:end])
+ except Exception as e:
+ logger.warning(
+ f"Exception encountered writing file data to stream: {e}"
+ )
+ self.deferred.errback()
+ self.file_length += end - start
+
+ callbacks = {
+ "on_header_field": on_header_field,
+ "on_header_value": on_header_value,
+ "on_part_data": on_part_data,
+ }
+ self.parser = multipart.MultipartParser(self.boundary, callbacks)
+
+ self.total_length += len(incoming_data)
+ if self.max_length is not None and self.total_length >= self.max_length:
+ self.deferred.errback(BodyExceededMaxSize())
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ assert self.transport is not None
+ self.transport.abortConnection()
+
+ try:
+ self.parser.write(incoming_data) # type: ignore[attr-defined]
+ except Exception as e:
+ logger.warning(f"Exception writing to multipart parser: {e}")
+ self.deferred.errback()
+ return
+
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
+
+ if reason.check(ResponseDone):
+ self.multipart_response.length = self.file_length
+ self.deferred.callback(self.multipart_response)
+ else:
+ self.deferred.errback(reason)
+
+
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
@@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d
+def read_multipart_response(
+ response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
+) -> "defer.Deferred[MultipartResponse]":
+ """
+ Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
+ the stream passed in and returning a deferred resolving to a MultipartResponse
+
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ boundary: the multipart/mixed boundary string
+ max_length: maximum allowable length of the response
+ """
+ d: defer.Deferred[MultipartResponse] = defer.Deferred()
+
+ # If the Content-Length header gives a size larger than the maximum allowed
+ # size, do not bother downloading the body.
+ if max_length is not None and response.length != UNKNOWN_LENGTH:
+ if response.length > max_length:
+ response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+ return d
+
+ response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
+ return d
+
+
def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
|