diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index a99a9e09fc..26b8711851 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -130,7 +130,8 @@ class Ratelimiter:
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
- update: Whether to count this check as performing the action
+ update: Whether to count this check as performing the action. If the action
+ cannot be performed, the user's action count is not incremented at all.
n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented
at all.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f0f5a37a57..7d80ff6998 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1871,6 +1871,52 @@ class FederationClient(FederationBase):
return filtered_statuses, filtered_failures
+ async def federation_download_media(
+ self,
+ destination: str,
+ media_id: str,
+ output_stream: BinaryIO,
+ max_size: int,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ ) -> Union[
+ Tuple[int, Dict[bytes, List[bytes]], bytes],
+ Tuple[int, Dict[bytes, List[bytes]]],
+ ]:
+ try:
+ return await self.transport_layer.federation_download_media(
+ destination,
+ media_id,
+ output_stream=output_stream,
+ max_size=max_size,
+ max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
+ )
+ except HttpResponseException as e:
+ # If an error is received that is due to an unrecognised endpoint,
+            # fall back to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
+ # and raise.
+ if not is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
+ destination,
+ media_id,
+ )
+
+ return await self.transport_layer.download_media_v3(
+ destination,
+ media_id,
+ output_stream=output_stream,
+ max_size=max_size,
+ max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
+ )
+
async def download_media(
self,
destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index af1336fe5f..206e91ed14 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -824,7 +824,6 @@ class TransportLayerClient:
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
-
return await self.client.get_file(
destination,
path,
@@ -852,7 +851,6 @@ class TransportLayerClient:
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
-
return await self.client.get_file(
destination,
path,
@@ -873,6 +871,29 @@ class TransportLayerClient:
ip_address=ip_address,
)
+ async def federation_download_media(
+ self,
+ destination: str,
+ media_id: str,
+ output_stream: BinaryIO,
+ max_size: int,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+ path = f"/_matrix/federation/v1/media/download/{media_id}"
+ return await self.client.federation_get_file(
+ destination,
+ path,
+ output_stream=output_stream,
+ max_size=max_size,
+ args={
+ "timeout_ms": str(max_timeout_ms),
+ },
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
+ )
+
def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index edaf0196d6..c44e5daa47 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -32,8 +32,8 @@ from synapse.federation.transport.server._base import (
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
+ FederationMediaDownloadServlet,
FederationUnstableClientKeysClaimServlet,
- FederationUnstableMediaDownloadServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
@@ -316,11 +316,8 @@ def register_servlets(
):
continue
- if servletclass == FederationUnstableMediaDownloadServlet:
- if (
- not hs.config.server.enable_media_repo
- or not hs.config.experimental.msc3916_authenticated_media_enabled
- ):
+ if servletclass == FederationMediaDownloadServlet:
+ if not hs.config.server.enable_media_repo:
continue
servletclass(
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 4e2717b565..e124481474 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -362,7 +362,7 @@ class BaseFederationServlet:
return None
if (
func.__self__.__class__.__name__ # type: ignore
- == "FederationUnstableMediaDownloadServlet"
+ == "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
@@ -374,7 +374,7 @@ class BaseFederationServlet:
else:
if (
func.__self__.__class__.__name__ # type: ignore
- == "FederationUnstableMediaDownloadServlet"
+ == "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 67bb907050..ec957768d4 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -790,7 +790,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
return 200, {"account_statuses": statuses, "failures": failures}
-class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
+class FederationMediaDownloadServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
@@ -798,7 +798,6 @@ class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
"""
PATH = "/media/download/(?P<media_id>[^/]*)"
- PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
RATELIMIT = True
def __init__(
@@ -858,5 +857,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
- FederationUnstableMediaDownloadServlet,
+ FederationMediaDownloadServlet,
)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4718517c97..56ad28eabf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -35,6 +35,8 @@ from typing import (
Union,
)
+import attr
+import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
@@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self._maybe_fail()
+@attr.s(auto_attribs=True, slots=True)
+class MultipartResponse:
+ """
+ A small class to hold parsed values of a multipart response.
+ """
+
+ json: bytes = b"{}"
+ length: Optional[int] = None
+ content_type: Optional[bytes] = None
+ disposition: Optional[bytes] = None
+ url: Optional[bytes] = None
+
+
+class _MultipartParserProtocol(protocol.Protocol):
+ """
+    Protocol to read and parse an MSC3916 multipart/mixed response
+ """
+
+ transport: Optional[ITCPTransport] = None
+
+ def __init__(
+ self,
+ stream: ByteWriteable,
+ deferred: defer.Deferred,
+ boundary: str,
+ max_length: Optional[int],
+ ) -> None:
+ self.stream = stream
+ self.deferred = deferred
+ self.boundary = boundary
+ self.max_length = max_length
+ self.parser = None
+ self.multipart_response = MultipartResponse()
+ self.has_redirect = False
+ self.in_json = False
+ self.json_done = False
+ self.file_length = 0
+ self.total_length = 0
+ self.in_disposition = False
+ self.in_content_type = False
+
+ def dataReceived(self, incoming_data: bytes) -> None:
+ if self.deferred.called:
+ return
+
+ # we don't have a parser yet, instantiate it
+ if not self.parser:
+
+ def on_header_field(data: bytes, start: int, end: int) -> None:
+ if data[start:end] == b"Location":
+ self.has_redirect = True
+ if data[start:end] == b"Content-Disposition":
+ self.in_disposition = True
+ if data[start:end] == b"Content-Type":
+ self.in_content_type = True
+
+ def on_header_value(data: bytes, start: int, end: int) -> None:
+ # the first header should be content-type for application/json
+ if not self.in_json and not self.json_done:
+ assert data[start:end] == b"application/json"
+ self.in_json = True
+ elif self.has_redirect:
+ self.multipart_response.url = data[start:end]
+ elif self.in_content_type:
+ self.multipart_response.content_type = data[start:end]
+ self.in_content_type = False
+ elif self.in_disposition:
+ self.multipart_response.disposition = data[start:end]
+ self.in_disposition = False
+
+ def on_part_data(data: bytes, start: int, end: int) -> None:
+ # we've seen json header but haven't written the json data
+ if self.in_json and not self.json_done:
+ self.multipart_response.json = data[start:end]
+ self.json_done = True
+ # we have a redirect header rather than a file, and have already captured it
+ elif self.has_redirect:
+ return
+ # otherwise we are in the file part
+ else:
+ logger.info("Writing multipart file data to stream")
+ try:
+ self.stream.write(data[start:end])
+ except Exception as e:
+ logger.warning(
+ f"Exception encountered writing file data to stream: {e}"
+ )
+ self.deferred.errback()
+ self.file_length += end - start
+
+ callbacks = {
+ "on_header_field": on_header_field,
+ "on_header_value": on_header_value,
+ "on_part_data": on_part_data,
+ }
+ self.parser = multipart.MultipartParser(self.boundary, callbacks)
+
+ self.total_length += len(incoming_data)
+ if self.max_length is not None and self.total_length >= self.max_length:
+ self.deferred.errback(BodyExceededMaxSize())
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ assert self.transport is not None
+ self.transport.abortConnection()
+
+ try:
+ self.parser.write(incoming_data) # type: ignore[attr-defined]
+ except Exception as e:
+ logger.warning(f"Exception writing to multipart parser: {e}")
+ self.deferred.errback()
+ return
+
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
+
+ if reason.check(ResponseDone):
+ self.multipart_response.length = self.file_length
+ self.deferred.callback(self.multipart_response)
+ else:
+ self.deferred.errback(reason)
+
+
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
@@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d
+def read_multipart_response(
+ response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
+) -> "defer.Deferred[MultipartResponse]":
+ """
+    Reads an MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
+ the stream passed in and returning a deferred resolving to a MultipartResponse
+
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ boundary: the multipart/mixed boundary string
+ max_length: maximum allowable length of the response
+ """
+ d: defer.Deferred[MultipartResponse] = defer.Deferred()
+
+ # If the Content-Length header gives a size larger than the maximum allowed
+ # size, do not bother downloading the body.
+ if max_length is not None and response.length != UNKNOWN_LENGTH:
+ if response.length > max_length:
+ response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+ return d
+
+ response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
+ return d
+
+
def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 104b803b0f..749b01dd0e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -75,9 +75,11 @@ from synapse.http.client import (
BlocklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
+ SimpleHttpClient,
_make_scheduler,
encode_query_args,
read_body_with_max_size,
+ read_multipart_response,
)
from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
@@ -466,6 +468,13 @@ class MatrixFederationHttpClient:
self._sleeper = AwakenableSleeper(self.reactor)
+ self._simple_http_client = SimpleHttpClient(
+ hs,
+ ip_blocklist=hs.config.server.federation_ip_range_blocklist,
+ ip_allowlist=hs.config.server.federation_ip_range_allowlist,
+ use_proxy=True,
+ )
+
def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""
@@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient:
)
return length, headers
+ async def federation_get_file(
+ self,
+ destination: str,
+ path: str,
+ output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ max_size: int,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ ignore_backoff: bool = False,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+ """GETs a file from a given homeserver over the federation /download endpoint
+ Args:
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path to GET.
+ output_stream: File to write the response body to.
+ download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+ requester IP
+ ip_address: IP address of the requester
+ max_size: maximum allowable size in bytes of the file
+ args: Optional dictionary used to create the query string.
+ ignore_backoff: true to ignore the historical backoff data
+ and try the request anyway.
+
+ Returns:
+ Resolves to an (int, dict, bytes) tuple of
+ the file length, a dict of the response headers, and the file json
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
+ SynapseError: If the requested file exceeds ratelimits or the response from the
+ remote server is not a multipart response
+ AssertionError: if the resolved multipart response's length is None
+ """
+ request = MatrixFederationRequest(
+ method="GET", destination=destination, path=path, query=args
+ )
+
+ # check for a minimum balance of 1MiB in ratelimiter before initiating request
+ send_req, _ = await download_ratelimiter.can_do_action(
+ requester=None, key=ip_address, n_actions=1048576, update=False
+ )
+
+ if not send_req:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+ response = await self._send_request(
+ request,
+ retry_on_dns_fail=retry_on_dns_fail,
+ ignore_backoff=ignore_backoff,
+ )
+
+ headers = dict(response.headers.getAllRawHeaders())
+
+ expected_size = response.length
+ # if we don't get an expected length then use the max length
+ if expected_size == UNKNOWN_LENGTH:
+ expected_size = max_size
+ logger.debug(
+ f"File size unknown, assuming file is max allowable size: {max_size}"
+ )
+
+ read_body, _ = await download_ratelimiter.can_do_action(
+ requester=None,
+ key=ip_address,
+ n_actions=expected_size,
+ )
+ if not read_body:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+ # this should be a multipart/mixed response with the boundary string in the header
+ try:
+ raw_content_type = headers.get(b"Content-Type")
+ assert raw_content_type is not None
+ content_type = raw_content_type[0].decode("UTF-8")
+ content_type_parts = content_type.split("boundary=")
+ boundary = content_type_parts[1]
+ except Exception:
+ msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present."
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
+
+ try:
+            # add one byte of headroom to the max size, since `_MultipartParserProtocol.dataReceived` errors when total >= max_length
+ deferred = read_multipart_response(
+ response, output_stream, boundary, expected_size + 1
+ )
+ deferred.addTimeout(self.default_timeout_seconds, self.reactor)
+ except BodyExceededMaxSize:
+ msg = "Requested file is too large > %r bytes" % (expected_size,)
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
+ except defer.TimeoutError as e:
+ logger.warning(
+ "{%s} [%s] Timed out reading response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
+ except ResponseFailed as e:
+ logger.warning(
+ "{%s} [%s] Failed to read response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=True) from e
+ except Exception as e:
+ logger.warning(
+ "{%s} [%s] Error reading response: %s",
+ request.txn_id,
+ request.destination,
+ e,
+ )
+ raise
+
+ multipart_response = await make_deferred_yieldable(deferred)
+ if not multipart_response.url:
+ assert multipart_response.length is not None
+ length = multipart_response.length
+ headers[b"Content-Type"] = [multipart_response.content_type]
+ headers[b"Content-Disposition"] = [multipart_response.disposition]
+
+ # the response contained a redirect url to download the file from
+ else:
+ str_url = multipart_response.url.decode("utf-8")
+ logger.info(
+ "{%s} [%s] File download redirected, now downloading from: %s",
+ request.txn_id,
+ request.destination,
+ str_url,
+ )
+ length, headers, _, _ = await self._simple_http_client.get_file(
+ str_url, output_stream, expected_size
+ )
+
+ logger.info(
+ "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
+ request.txn_id,
+ request.destination,
+ response.code,
+ response.phrase.decode("ascii", errors="replace"),
+ length,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ return length, headers, multipart_response.json
+
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 7ad0b7c3cf..1b268ce4d4 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -221,6 +221,7 @@ def add_file_headers(
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
@@ -302,12 +303,37 @@ async def respond_with_multipart_responder(
)
return
+ if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
+ disposition = "inline"
+ else:
+ disposition = "attachment"
+
+ def _quote(x: str) -> str:
+ return urllib.parse.quote(x.encode("utf-8"))
+
+ if media_info.upload_name:
+ if _can_encode_filename_as_token(media_info.upload_name):
+ disposition = "%s; filename=%s" % (
+ disposition,
+ media_info.upload_name,
+ )
+ else:
+ disposition = "%s; filename*=utf-8''%s" % (
+ disposition,
+ _quote(media_info.upload_name),
+ )
+
from synapse.media.media_storage import MultipartFileConsumer
# note that currently the json_object is just {}, this will change when linked media
# is implemented
multipart_consumer = MultipartFileConsumer(
- clock, request, media_info.media_type, {}, media_info.media_length
+ clock,
+ request,
+ media_info.media_type,
+ {},
+ disposition,
+ media_info.media_length,
)
logger.debug("Responding to media request with responder %s", responder)
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1436329fad..542642b900 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -480,6 +480,7 @@ class MediaRepository:
name: Optional[str],
max_timeout_ms: int,
ip_address: str,
+ use_federation_endpoint: bool,
) -> None:
"""Respond to requests for remote media.
@@ -492,6 +493,8 @@ class MediaRepository:
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
ip_address: the IP address of the requester
+ use_federation_endpoint: whether to request the remote media over the new
+ federation `/download` endpoint
Returns:
Resolves once a response has successfully been written to request
@@ -522,6 +525,7 @@ class MediaRepository:
max_timeout_ms,
self.download_ratelimiter,
ip_address,
+ use_federation_endpoint,
)
# We deliberately stream the file outside the lock
@@ -569,6 +573,7 @@ class MediaRepository:
max_timeout_ms,
self.download_ratelimiter,
ip_address,
+ False,
)
# Ensure we actually use the responder so that it releases resources
@@ -585,6 +590,7 @@ class MediaRepository:
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
+ use_federation_endpoint: bool,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -598,6 +604,8 @@ class MediaRepository:
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP.
ip_address: the IP address of the requester
+ use_federation_endpoint: whether to request the remote media over the new federation
+ /download endpoint
Returns:
A tuple of responder and the media info of the file.
@@ -629,9 +637,23 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
try:
- media_info = await self._download_remote_file(
- server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
- )
+ if not use_federation_endpoint:
+ media_info = await self._download_remote_file(
+ server_name,
+ media_id,
+ max_timeout_ms,
+ download_ratelimiter,
+ ip_address,
+ )
+ else:
+ media_info = await self._federation_download_remote_file(
+ server_name,
+ media_id,
+ max_timeout_ms,
+ download_ratelimiter,
+ ip_address,
+ )
+
except SynapseError:
raise
except Exception as e:
@@ -775,6 +797,129 @@ class MediaRepository:
quarantined_by=None,
)
+ async def _federation_download_remote_file(
+ self,
+ server_name: str,
+ media_id: str,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ ) -> RemoteMedia:
+ """Attempt to download the remote file from the given server name.
+        Generates a new random file_id to use as the local id, and downloads the file
+        over the federation v1 download endpoint
+
+ Args:
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
+ remote server). This is different than the file_id, which is
+ locally generated.
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP
+ ip_address: the IP address of the requester
+
+ Returns:
+ The media info of the file.
+ """
+
+ file_id = random_string(24)
+
+ file_info = FileInfo(server_name=server_name, file_id=file_id)
+
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
+ try:
+ res = await self.client.federation_download_media(
+ server_name,
+ media_id,
+ output_stream=f,
+ max_size=self.max_upload_size,
+ max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
+ )
+                # if we had to fall back to the _matrix/media endpoint, it will only return
+                # the headers and length; check the tuple's length before unpacking
+ if len(res) == 3:
+ length, headers, json = res
+ else:
+ length, headers = res
+ except RequestSendFailed as e:
+ logger.warning(
+ "Request failed fetching remote media %s/%s: %r",
+ server_name,
+ media_id,
+ e,
+ )
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except HttpResponseException as e:
+ logger.warning(
+ "HTTP error fetching remote media %s/%s: %s",
+ server_name,
+ media_id,
+ e.response,
+ )
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise e.to_synapse_error()
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.warning(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
+ raise
+ except NotRetryingDestination:
+ logger.warning("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+ upload_name = get_filename_from_headers(headers)
+ time_now_ms = self.clock.time_msec()
+
+ # Multiple remote media download requests can race (when using
+ # multiple media repos), so this may throw a violation constraint
+ # exception. If it does we'll delete the newly downloaded file from
+ # disk (as we're in the ctx manager).
+ #
+ # However: we've already called `finish()` so we may have also
+ # written to the storage providers. This is preferable to the
+ # alternative where we call `finish()` *after* this, where we could
+ # end up having an entry in the DB but fail to write the files to
+ # the storage providers.
+ await self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=time_now_ms,
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
+
+ logger.debug("Stored remote media in file %r", fname)
+
+ return RemoteMedia(
+ media_origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ media_length=length,
+ upload_name=upload_name,
+ created_ts=time_now_ms,
+ filesystem_id=file_id,
+ last_access_ts=time_now_ms,
+ quarantined_by=None,
+ )
+
def _get_thumbnail_requirements(
self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]:
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 1be2c9b5f5..2a106bb0eb 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -401,13 +401,14 @@ class MultipartFileConsumer:
wrapped_consumer: interfaces.IConsumer,
file_content_type: str,
json_object: JsonDict,
- content_length: Optional[int] = None,
+ disposition: str,
+ content_length: Optional[int],
) -> None:
self.clock = clock
self.wrapped_consumer = wrapped_consumer
self.json_field = json_object
self.json_field_written = False
- self.content_type_written = False
+ self.file_headers_written = False
self.file_content_type = file_content_type
self.boundary = uuid4().hex.encode("ascii")
@@ -420,6 +421,7 @@ class MultipartFileConsumer:
self.paused = False
self.length = content_length
+ self.disposition = disposition
### IConsumer APIs ###
@@ -488,11 +490,13 @@ class MultipartFileConsumer:
self.json_field_written = True
# if we haven't written the content type yet, do so
- if not self.content_type_written:
+ if not self.file_headers_written:
type = self.file_content_type.encode("utf-8")
content_type = Header(b"Content-Type", type)
- self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
- self.content_type_written = True
+ self.wrapped_consumer.write(bytes(content_type) + CRLF)
+ disp_header = Header(b"Content-Disposition", self.disposition)
+ self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF)
+ self.file_headers_written = True
self.wrapped_consumer.write(data)
@@ -506,7 +510,6 @@ class MultipartFileConsumer:
producing data for good.
"""
assert self.producer is not None
-
self.paused = True
self.producer.stopProducing()
@@ -518,7 +521,6 @@ class MultipartFileConsumer:
the time being, and to stop until C{resumeProducing()} is called.
"""
assert self.producer is not None
-
self.paused = True
if self.streaming:
@@ -549,7 +551,7 @@ class MultipartFileConsumer:
"""
if not self.length:
return None
- # calculate length of json field and content-type header
+ # calculate length of json field and content-type, disposition headers
json_field = json.dumps(self.json_field)
json_bytes = json_field.encode("utf-8")
json_length = len(json_bytes)
@@ -558,9 +560,13 @@ class MultipartFileConsumer:
content_type = Header(b"Content-Type", type)
type_length = len(bytes(content_type))
- # 154 is the length of the elements that aren't variable, ie
+ disp = self.disposition.encode("utf-8")
+ disp_header = Header(b"Content-Disposition", disp)
+ disp_length = len(bytes(disp_header))
+
+ # 156 is the length of the elements that aren't variable, ie
# CRLFs and boundary strings, etc
- self.length += json_length + type_length + 154
+ self.length += json_length + type_length + disp_length + 156
return self.length
@@ -569,7 +575,6 @@ class MultipartFileConsumer:
async def _resumeProducingRepeatedly(self) -> None:
assert self.producer is not None
assert not self.streaming
-
producer = cast("interfaces.IPullProducer", self.producer)
self.paused = False
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 0024ccf708..c94d454a28 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -145,6 +145,10 @@ class ClientRestResource(JsonResource):
password_policy.register_servlets(hs, client_resource)
knock.register_servlets(hs, client_resource)
appservice_ping.register_servlets(hs, client_resource)
+ if hs.config.server.enable_media_repo:
+ from synapse.rest.client import media
+
+ media.register_servlets(hs, client_resource)
# moving to /_synapse/admin
if is_main_process:
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index 0c089163c1..c0ae5dd66f 100644
--- a/synapse/rest/client/media.py
+++ b/synapse/rest/client/media.py
@@ -22,6 +22,7 @@
import logging
import re
+from typing import Optional
from synapse.http.server import (
HttpServer,
@@ -194,14 +195,76 @@ class UnstableThumbnailResource(RestServlet):
self.media_repo.mark_recently_accessed(server_name, media_id)
+class DownloadResource(RestServlet):
+ PATTERNS = [
+ re.compile(
+ "/_matrix/client/v1/media/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
+ )
+ ]
+
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
+ super().__init__()
+ self.media_repo = media_repo
+ self._is_mine_server_name = hs.is_mine_server_name
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ file_name: Optional[str] = None,
+ ) -> None:
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+
+ await self.auth.get_user_by_req(request)
+
+ set_cors_headers(request)
+ set_corp_headers(request)
+ request.setHeader(
+ b"Content-Security-Policy",
+ b"sandbox;"
+ b" default-src 'none';"
+ b" script-src 'none';"
+ b" plugin-types application/pdf;"
+ b" style-src 'unsafe-inline';"
+ b" media-src 'self';"
+ b" object-src 'self';",
+ )
+ # Limited non-standard form of CSP for IE11
+ request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
+ request.setHeader(b"Referrer-Policy", b"no-referrer")
+ max_timeout_ms = parse_integer(
+ request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+ )
+ max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
+ if self._is_mine_server_name(server_name):
+ await self.media_repo.get_local_media(
+ request, media_id, file_name, max_timeout_ms
+ )
+ else:
+ ip_address = request.getClientAddress().host
+ await self.media_repo.get_remote_media(
+ request,
+ server_name,
+ media_id,
+ file_name,
+ max_timeout_ms,
+ ip_address,
+ True,
+ )
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- if hs.config.experimental.msc3916_authenticated_media_enabled:
- media_repo = hs.get_media_repository()
- if hs.config.media.url_preview_enabled:
- UnstablePreviewURLServlet(
- hs, media_repo, media_repo.media_storage
- ).register(http_server)
- UnstableMediaConfigResource(hs).register(http_server)
- UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+ media_repo = hs.get_media_repository()
+ if hs.config.media.url_preview_enabled:
+ UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register(
http_server
)
+ UnstableMediaConfigResource(hs).register(http_server)
+ UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+ http_server
+ )
+ DownloadResource(hs, media_repo).register(http_server)
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 1628d58926..c32c626905 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -105,4 +105,5 @@ class DownloadResource(RestServlet):
file_name,
max_timeout_ms,
ip_address,
+ False,
)
|