diff options
author | Shay <hillerys@element.io> | 2024-07-02 06:07:04 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-02 14:07:04 +0100 |
commit | 8f890447b0f8b6cbe369b162670185e8c746b2f2 (patch) | |
tree | c8c290661a59b06257ce7e2fda19e799d83825eb /synapse/http | |
parent | Fix sync waiting for an invalid token from the "future" (#17386) (diff) | |
download | synapse-8f890447b0f8b6cbe369b162670185e8c746b2f2.tar.xz |
Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365)
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/client.py | 152 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 192 |
2 files changed, 344 insertions, 0 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py index 4718517c97..56ad28eabf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -35,6 +35,8 @@ from typing import ( Union, ) +import attr +import multipart import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet @@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): self._maybe_fail() +@attr.s(auto_attribs=True, slots=True) +class MultipartResponse: + """ + A small class to hold parsed values of a multipart response. + """ + + json: bytes = b"{}" + length: Optional[int] = None + content_type: Optional[bytes] = None + disposition: Optional[bytes] = None + url: Optional[bytes] = None + + +class _MultipartParserProtocol(protocol.Protocol): + """ + Protocol to read and parse a MSC3916 multipart/mixed response + """ + + transport: Optional[ITCPTransport] = None + + def __init__( + self, + stream: ByteWriteable, + deferred: defer.Deferred, + boundary: str, + max_length: Optional[int], + ) -> None: + self.stream = stream + self.deferred = deferred + self.boundary = boundary + self.max_length = max_length + self.parser = None + self.multipart_response = MultipartResponse() + self.has_redirect = False + self.in_json = False + self.json_done = False + self.file_length = 0 + self.total_length = 0 + self.in_disposition = False + self.in_content_type = False + + def dataReceived(self, incoming_data: bytes) -> None: + if self.deferred.called: + return + + # we don't have a parser yet, instantiate it + if not self.parser: + + def on_header_field(data: bytes, start: int, end: int) -> None: + if data[start:end] == b"Location": + self.has_redirect = True + if data[start:end] == b"Content-Disposition": + self.in_disposition = True + if data[start:end] == b"Content-Type": + self.in_content_type = True + + def on_header_value(data: bytes, start: int, end: int) -> None: + # the first header should be content-type for application/json + if not self.in_json and not self.json_done: + assert data[start:end] == b"application/json" + self.in_json = True + elif self.has_redirect: + self.multipart_response.url = data[start:end] + elif self.in_content_type: + self.multipart_response.content_type = data[start:end] + self.in_content_type = False + elif self.in_disposition: + self.multipart_response.disposition = data[start:end] + self.in_disposition = False + + def on_part_data(data: bytes, start: int, end: int) -> None: + # we've seen json header but haven't written the json data + if self.in_json and not self.json_done: + self.multipart_response.json = data[start:end] + self.json_done = True + # we have a redirect header rather than a file, and have already captured it + elif self.has_redirect: + return + # otherwise we are in the file part + else: + logger.info("Writing multipart file data to stream") + try: + self.stream.write(data[start:end]) + except Exception as e: + logger.warning( + f"Exception encountered writing file data to stream: {e}" + ) + self.deferred.errback() + self.file_length += end - start + + callbacks = { + "on_header_field": on_header_field, + "on_header_value": on_header_value, + "on_part_data": on_part_data, + } + self.parser = multipart.MultipartParser(self.boundary, callbacks) + + self.total_length += len(incoming_data) + if self.max_length is not None and self.total_length >= self.max_length: + self.deferred.errback(BodyExceededMaxSize()) + # Close the connection (forcefully) since all the data will get + # discarded anyway. + assert self.transport is not None + self.transport.abortConnection() + + try: + self.parser.write(incoming_data) # type: ignore[attr-defined] + except Exception as e: + logger.warning(f"Exception writing to multipart parser: {e}") + self.deferred.errback() + return + + def connectionLost(self, reason: Failure = connectionDone) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + + if reason.check(ResponseDone): + self.multipart_response.length = self.file_length + self.deferred.callback(self.multipart_response) + else: + self.deferred.errback(reason) + + class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" @@ -1091,6 +1217,32 @@ def read_body_with_max_size( return d +def read_multipart_response( + response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int] +) -> "defer.Deferred[MultipartResponse]": + """ + Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into + the stream passed in and returning a deferred resolving to a MultipartResponse + + Args: + response: The HTTP response to read from. + stream: The file-object to write to. + boundary: the multipart/mixed boundary string + max_length: maximum allowable length of the response + """ + d: defer.Deferred[MultipartResponse] = defer.Deferred() + + # If the Content-Length header gives a size larger than the maximum allowed + # size, do not bother downloading the body. + if max_length is not None and response.length != UNKNOWN_LENGTH: + if response.length > max_length: + response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d)) + return d + + response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length)) + return d + + def encode_query_args(args: Optional[QueryParams]) -> bytes: """ Encodes a map of query arguments to bytes which can be appended to a URL. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 104b803b0f..749b01dd0e 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -75,9 +75,11 @@ from synapse.http.client import ( BlocklistingAgentWrapper, BodyExceededMaxSize, ByteWriteable, + SimpleHttpClient, _make_scheduler, encode_query_args, read_body_with_max_size, + read_multipart_response, ) from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent @@ -466,6 +468,13 @@ class MatrixFederationHttpClient: self._sleeper = AwakenableSleeper(self.reactor) + self._simple_http_client = SimpleHttpClient( + hs, + ip_blocklist=hs.config.server.federation_ip_range_blocklist, + ip_allowlist=hs.config.server.federation_ip_range_allowlist, + use_proxy=True, + ) + def wake_destination(self, destination: str) -> None: """Called when the remote server may have come back online.""" @@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient: ) return length, headers + async def federation_get_file( + self, + destination: str, + path: str, + output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: str, + max_size: int, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + ignore_backoff: bool = False, + ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: + """GETs a file from a given homeserver over the federation /download endpoint + Args: + destination: The remote server to send the HTTP request to. + path: The HTTP path to GET. + output_stream: File to write the response body to. + download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to + requester IP + ip_address: IP address of the requester + max_size: maximum allowable size in bytes of the file + args: Optional dictionary used to create the query string. + ignore_backoff: true to ignore the historical backoff data + and try the request anyway. + + Returns: + Resolves to an (int, dict, bytes) tuple of + the file length, a dict of the response headers, and the file json + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. + SynapseError: If the requested file exceeds ratelimits or the response from the + remote server is not a multipart response + AssertionError: if the resolved multipart response's length is None + """ + request = MatrixFederationRequest( + method="GET", destination=destination, path=path, query=args + ) + + # check for a minimum balance of 1MiB in ratelimiter before initiating request + send_req, _ = await download_ratelimiter.can_do_action( + requester=None, key=ip_address, n_actions=1048576, update=False + ) + + if not send_req: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + + response = await self._send_request( + request, + retry_on_dns_fail=retry_on_dns_fail, + ignore_backoff=ignore_backoff, + ) + + headers = dict(response.headers.getAllRawHeaders()) + + expected_size = response.length + # if we don't get an expected length then use the max length + if expected_size == UNKNOWN_LENGTH: + expected_size = max_size + logger.debug( + f"File size unknown, assuming file is max allowable size: {max_size}" + ) + + read_body, _ = await download_ratelimiter.can_do_action( + requester=None, + key=ip_address, + n_actions=expected_size, + ) + if not read_body: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + + # this should be a multipart/mixed response with the boundary string in the header + try: + raw_content_type = headers.get(b"Content-Type") + assert raw_content_type is not None + content_type = raw_content_type[0].decode("UTF-8") + content_type_parts = content_type.split("boundary=") + boundary = content_type_parts[1] + except Exception: + msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present." + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg) + + try: + # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >= + deferred = read_multipart_response( + response, output_stream, boundary, expected_size + 1 + ) + deferred.addTimeout(self.default_timeout_seconds, self.reactor) + except BodyExceededMaxSize: + msg = "Requested file is too large > %r bytes" % (expected_size,) + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) + except defer.TimeoutError as e: + logger.warning( + "{%s} [%s] Timed out reading response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except ResponseFailed as e: + logger.warning( + "{%s} [%s] Failed to read response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=True) from e + except Exception as e: + logger.warning( + "{%s} [%s] Error reading response: %s", + request.txn_id, + request.destination, + e, + ) + raise + + multipart_response = await make_deferred_yieldable(deferred) + if not multipart_response.url: + assert multipart_response.length is not None + length = multipart_response.length + headers[b"Content-Type"] = [multipart_response.content_type] + headers[b"Content-Disposition"] = [multipart_response.disposition] + + # the response contained a redirect url to download the file from + else: + str_url = multipart_response.url.decode("utf-8") + logger.info( + "{%s} [%s] File download redirected, now downloading from: %s", + request.txn_id, + request.destination, + str_url, + ) + length, headers, _, _ = await self._simple_http_client.get_file( + str_url, output_stream, expected_size + ) + + logger.info( + "{%s} [%s] Completed: %d %s [%d bytes] %s %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode("ascii", errors="replace"), + length, + request.method, + request.uri.decode("ascii"), + ) + return length, headers, multipart_response.json + def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): |