summary refs log tree commit diff
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-07-02 06:07:04 -0700
committerGitHub <noreply@github.com>2024-07-02 14:07:04 +0100
commit8f890447b0f8b6cbe369b162670185e8c746b2f2 (patch)
treec8c290661a59b06257ce7e2fda19e799d83825eb
parentFix sync waiting for an invalid token from the "future" (#17386) (diff)
downloadsynapse-8f890447b0f8b6cbe369b162670185e8c746b2f2.tar.xz
Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365)
-rw-r--r--changelog.d/17365.feature1
-rwxr-xr-xdocker/configure_workers_and_start.py3
-rw-r--r--docs/upgrade.md13
-rw-r--r--docs/workers.md1
-rw-r--r--mypy.ini3
-rw-r--r--poetry.lock18
-rw-r--r--pyproject.toml2
-rw-r--r--synapse/api/ratelimiting.py3
-rw-r--r--synapse/federation/federation_client.py46
-rw-r--r--synapse/federation/transport/client.py25
-rw-r--r--synapse/federation/transport/server/__init__.py9
-rw-r--r--synapse/federation/transport/server/_base.py4
-rw-r--r--synapse/federation/transport/server/federation.py5
-rw-r--r--synapse/http/client.py152
-rw-r--r--synapse/http/matrixfederationclient.py192
-rw-r--r--synapse/media/_base.py28
-rw-r--r--synapse/media/media_repository.py151
-rw-r--r--synapse/media/media_storage.py27
-rw-r--r--synapse/rest/__init__.py4
-rw-r--r--synapse/rest/client/media.py79
-rw-r--r--synapse/rest/media/download_resource.py1
-rw-r--r--tests/federation/test_federation_media.py35
-rw-r--r--tests/http/test_client.py143
-rw-r--r--tests/media/test_media_storage.py14
-rw-r--r--tests/replication/test_multi_media_repo.py234
-rw-r--r--tests/rest/client/test_media.py609
26 files changed, 1718 insertions, 84 deletions
diff --git a/changelog.d/17365.feature b/changelog.d/17365.feature
new file mode 100644
index 0000000000..f90dc84e38
--- /dev/null
+++ b/changelog.d/17365.feature
@@ -0,0 +1 @@
+Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint.
\ No newline at end of file
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index 063f3727f9..b6690f3404 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -117,7 +117,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
     },
     "media_repository": {
         "app": "synapse.app.generic_worker",
-        "listener_resources": ["media"],
+        "listener_resources": ["media", "client"],
         "endpoint_patterns": [
             "^/_matrix/media/",
             "^/_synapse/admin/v1/purge_media_cache$",
@@ -125,6 +125,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
             "^/_synapse/admin/v1/user/.*/media.*$",
             "^/_synapse/admin/v1/media/.*$",
             "^/_synapse/admin/v1/quarantine_media/.*$",
+            "^/_matrix/client/v1/media/.*$",
         ],
         # The first configured media worker will run the media background jobs
         "shared_extra_conf": {
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 99be4122bb..cf53f56b06 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -117,6 +117,19 @@ each upgrade are complete before moving on to the next upgrade, to avoid
 stacking them up. You can monitor the currently running background updates with
 [the Admin API](usage/administration/admin_api/background_updates.html#status).
 
+# Upgrading to v1.111.0
+
+## New worker endpoints for authenticated client media
+
+[Media repository workers](./workers.md#synapseappmedia_repository) handling
+Media APIs can now handle the following endpoint pattern:
+
+```
+^/_matrix/client/v1/media/.*$
+```
+
+Please update your reverse proxy configuration.
+
 # Upgrading to v1.106.0
 
 ## Minimum supported Rust version
diff --git a/docs/workers.md b/docs/workers.md
index 1f6bfd9e7f..22fde488a9 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -739,6 +739,7 @@ An example for a federation sender instance:
 Handles the media repository. It can handle all endpoints starting with:
 
     /_matrix/media/
+    /_matrix/client/v1/media/
 
 ... and the following regular expressions matching media-specific administration APIs:
 
diff --git a/mypy.ini b/mypy.ini
index 1a2b9ea410..3fca15c01b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -96,3 +96,6 @@ ignore_missing_imports = True
 # https://github.com/twisted/treq/pull/366
 [mypy-treq.*]
 ignore_missing_imports = True
+
+[mypy-multipart.*]
+ignore_missing_imports = True
diff --git a/poetry.lock b/poetry.lock
index 99c3b62c7d..8142406e3f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -2040,6 +2040,20 @@ files = [
 six = ">=1.5"
 
 [[package]]
+name = "python-multipart"
+version = "0.0.9"
+description = "A streaming multipart parser for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
+    {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
+]
+
+[package.extras]
+dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
+
+[[package]]
 name = "pytz"
 version = "2022.7.1"
 description = "World timezone definitions, modern and historical"
@@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.0"
-content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea"
+content-hash = "e8d5806e10eb69bc06900fde18ea3df38f38490ab6baa73fe4a563dfb6abacba"
diff --git a/pyproject.toml b/pyproject.toml
index bbf9c78420..0555e67613 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3"
 # needed.
 setuptools_rust = ">=1.3"
 
+# This is used for parsing multipart responses
+python-multipart = ">=0.0.9"
 
 # Optional Dependencies
 # ---------------------
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index a99a9e09fc..26b8711851 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -130,7 +130,8 @@ class Ratelimiter:
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
                 Overrides the value set during instantiation if set.
-            update: Whether to count this check as performing the action
+            update: Whether to count this check as performing the action. If the action
+                cannot be performed, the user's action count is not incremented at all.
             n_actions: The number of times the user wants to do this action. If the user
                 cannot do all of the actions, the user's action count is not incremented
                 at all.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f0f5a37a57..7d80ff6998 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1871,6 +1871,52 @@ class FederationClient(FederationBase):
 
         return filtered_statuses, filtered_failures
 
+    async def federation_download_media(
+        self,
+        destination: str,
+        media_id: str,
+        output_stream: BinaryIO,
+        max_size: int,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+    ) -> Union[
+        Tuple[int, Dict[bytes, List[bytes]], bytes],
+        Tuple[int, Dict[bytes, List[bytes]]],
+    ]:
+        try:
+            return await self.transport_layer.federation_download_media(
+                destination,
+                media_id,
+                output_stream=output_stream,
+                max_size=max_size,
+                max_timeout_ms=max_timeout_ms,
+                download_ratelimiter=download_ratelimiter,
+                ip_address=ip_address,
+            )
+        except HttpResponseException as e:
+            # If an error is received that is due to an unrecognised endpoint,
+            # fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
+            # and raise.
+            if not is_unknown_endpoint(e):
+                raise
+
+        logger.debug(
+            "Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
+            destination,
+            media_id,
+        )
+
+        return await self.transport_layer.download_media_v3(
+            destination,
+            media_id,
+            output_stream=output_stream,
+            max_size=max_size,
+            max_timeout_ms=max_timeout_ms,
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
+        )
+
     async def download_media(
         self,
         destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index af1336fe5f..206e91ed14 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -824,7 +824,6 @@ class TransportLayerClient:
         ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/r0/download/{destination}/{media_id}"
-
         return await self.client.get_file(
             destination,
             path,
@@ -852,7 +851,6 @@ class TransportLayerClient:
         ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/v3/download/{destination}/{media_id}"
-
         return await self.client.get_file(
             destination,
             path,
@@ -873,6 +871,29 @@ class TransportLayerClient:
             ip_address=ip_address,
         )
 
+    async def federation_download_media(
+        self,
+        destination: str,
+        media_id: str,
+        output_stream: BinaryIO,
+        max_size: int,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+        path = f"/_matrix/federation/v1/media/download/{media_id}"
+        return await self.client.federation_get_file(
+            destination,
+            path,
+            output_stream=output_stream,
+            max_size=max_size,
+            args={
+                "timeout_ms": str(max_timeout_ms),
+            },
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
+        )
+
 
 def _create_path(federation_prefix: str, path: str, *args: str) -> str:
     """
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index edaf0196d6..c44e5daa47 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -32,8 +32,8 @@ from synapse.federation.transport.server._base import (
 from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
     FederationAccountStatusServlet,
+    FederationMediaDownloadServlet,
     FederationUnstableClientKeysClaimServlet,
-    FederationUnstableMediaDownloadServlet,
 )
 from synapse.http.server import HttpServer, JsonResource
 from synapse.http.servlet import (
@@ -316,11 +316,8 @@ def register_servlets(
             ):
                 continue
 
-            if servletclass == FederationUnstableMediaDownloadServlet:
-                if (
-                    not hs.config.server.enable_media_repo
-                    or not hs.config.experimental.msc3916_authenticated_media_enabled
-                ):
+            if servletclass == FederationMediaDownloadServlet:
+                if not hs.config.server.enable_media_repo:
                     continue
 
             servletclass(
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 4e2717b565..e124481474 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -362,7 +362,7 @@ class BaseFederationServlet:
                                 return None
                             if (
                                 func.__self__.__class__.__name__  # type: ignore
-                                == "FederationUnstableMediaDownloadServlet"
+                                == "FederationMediaDownloadServlet"
                             ):
                                 response = await func(
                                     origin, content, request, *args, **kwargs
@@ -374,7 +374,7 @@ class BaseFederationServlet:
                     else:
                         if (
                             func.__self__.__class__.__name__  # type: ignore
-                            == "FederationUnstableMediaDownloadServlet"
+                            == "FederationMediaDownloadServlet"
                         ):
                             response = await func(
                                 origin, content, request, *args, **kwargs
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 67bb907050..ec957768d4 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -790,7 +790,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
         return 200, {"account_statuses": statuses, "failures": failures}
 
 
-class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
+class FederationMediaDownloadServlet(BaseFederationServerServlet):
     """
     Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
     a multipart/mixed response consisting of a JSON object and the requested media
@@ -798,7 +798,6 @@ class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
     """
 
     PATH = "/media/download/(?P<media_id>[^/]*)"
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
     RATELIMIT = True
 
     def __init__(
@@ -858,5 +857,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
     FederationAccountStatusServlet,
-    FederationUnstableMediaDownloadServlet,
+    FederationMediaDownloadServlet,
 )
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4718517c97..56ad28eabf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -35,6 +35,8 @@ from typing import (
     Union,
 )
 
+import attr
+import multipart
 import treq
 from canonicaljson import encode_canonical_json
 from netaddr import AddrFormatError, IPAddress, IPSet
@@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
         self._maybe_fail()
 
 
+@attr.s(auto_attribs=True, slots=True)
+class MultipartResponse:
+    """
+    A small class to hold parsed values of a multipart response.
+    """
+
+    json: bytes = b"{}"
+    length: Optional[int] = None
+    content_type: Optional[bytes] = None
+    disposition: Optional[bytes] = None
+    url: Optional[bytes] = None
+
+
+class _MultipartParserProtocol(protocol.Protocol):
+    """
+    Protocol to read and parse a MSC3916 multipart/mixed response
+    """
+
+    transport: Optional[ITCPTransport] = None
+
+    def __init__(
+        self,
+        stream: ByteWriteable,
+        deferred: defer.Deferred,
+        boundary: str,
+        max_length: Optional[int],
+    ) -> None:
+        self.stream = stream
+        self.deferred = deferred
+        self.boundary = boundary
+        self.max_length = max_length
+        self.parser = None
+        self.multipart_response = MultipartResponse()
+        self.has_redirect = False
+        self.in_json = False
+        self.json_done = False
+        self.file_length = 0
+        self.total_length = 0
+        self.in_disposition = False
+        self.in_content_type = False
+
+    def dataReceived(self, incoming_data: bytes) -> None:
+        if self.deferred.called:
+            return
+
+        # we don't have a parser yet, instantiate it
+        if not self.parser:
+
+            def on_header_field(data: bytes, start: int, end: int) -> None:
+                if data[start:end] == b"Location":
+                    self.has_redirect = True
+                if data[start:end] == b"Content-Disposition":
+                    self.in_disposition = True
+                if data[start:end] == b"Content-Type":
+                    self.in_content_type = True
+
+            def on_header_value(data: bytes, start: int, end: int) -> None:
+                # the first header should be content-type for application/json
+                if not self.in_json and not self.json_done:
+                    assert data[start:end] == b"application/json"
+                    self.in_json = True
+                elif self.has_redirect:
+                    self.multipart_response.url = data[start:end]
+                elif self.in_content_type:
+                    self.multipart_response.content_type = data[start:end]
+                    self.in_content_type = False
+                elif self.in_disposition:
+                    self.multipart_response.disposition = data[start:end]
+                    self.in_disposition = False
+
+            def on_part_data(data: bytes, start: int, end: int) -> None:
+                # we've seen json header but haven't written the json data
+                if self.in_json and not self.json_done:
+                    self.multipart_response.json = data[start:end]
+                    self.json_done = True
+                # we have a redirect header rather than a file, and have already captured it
+                elif self.has_redirect:
+                    return
+                # otherwise we are in the file part
+                else:
+                    logger.info("Writing multipart file data to stream")
+                    try:
+                        self.stream.write(data[start:end])
+                    except Exception as e:
+                        logger.warning(
+                            f"Exception encountered writing file data to stream: {e}"
+                        )
+                        self.deferred.errback()
+                    self.file_length += end - start
+
+            callbacks = {
+                "on_header_field": on_header_field,
+                "on_header_value": on_header_value,
+                "on_part_data": on_part_data,
+            }
+            self.parser = multipart.MultipartParser(self.boundary, callbacks)
+
+        self.total_length += len(incoming_data)
+        if self.max_length is not None and self.total_length >= self.max_length:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            assert self.transport is not None
+            self.transport.abortConnection()
+
+        try:
+            self.parser.write(incoming_data)  # type: ignore[attr-defined]
+        except Exception as e:
+            logger.warning(f"Exception writing to multipart parser: {e}")
+            self.deferred.errback()
+            return
+
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
+        if reason.check(ResponseDone):
+            self.multipart_response.length = self.file_length
+            self.deferred.callback(self.multipart_response)
+        else:
+            self.deferred.errback(reason)
+
+
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
@@ -1091,6 +1217,32 @@ def read_body_with_max_size(
     return d
 
 
+def read_multipart_response(
+    response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
+) -> "defer.Deferred[MultipartResponse]":
+    """
+    Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
+    the stream passed in and returning a deferred resolving to a MultipartResponse
+
+    Args:
+        response: The HTTP response to read from.
+        stream: The file-object to write to.
+        boundary: the multipart/mixed boundary string
+        max_length: maximum allowable length of the response
+    """
+    d: defer.Deferred[MultipartResponse] = defer.Deferred()
+
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_length is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_length:
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d
+
+    response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
+    return d
+
+
 def encode_query_args(args: Optional[QueryParams]) -> bytes:
     """
     Encodes a map of query arguments to bytes which can be appended to a URL.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 104b803b0f..749b01dd0e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -75,9 +75,11 @@ from synapse.http.client import (
     BlocklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    SimpleHttpClient,
     _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
+    read_multipart_response,
 )
 from synapse.http.connectproxyclient import BearerProxyCredentials
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
@@ -466,6 +468,13 @@ class MatrixFederationHttpClient:
 
         self._sleeper = AwakenableSleeper(self.reactor)
 
+        self._simple_http_client = SimpleHttpClient(
+            hs,
+            ip_blocklist=hs.config.server.federation_ip_range_blocklist,
+            ip_allowlist=hs.config.server.federation_ip_range_allowlist,
+            use_proxy=True,
+        )
+
     def wake_destination(self, destination: str) -> None:
         """Called when the remote server may have come back online."""
 
@@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient:
         )
         return length, headers
 
+    async def federation_get_file(
+        self,
+        destination: str,
+        path: str,
+        output_stream: BinaryIO,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+        max_size: int,
+        args: Optional[QueryParams] = None,
+        retry_on_dns_fail: bool = True,
+        ignore_backoff: bool = False,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+        """GETs a file from a given homeserver over the federation /download endpoint
+        Args:
+            destination: The remote server to send the HTTP request to.
+            path: The HTTP path to GET.
+            output_stream: File to write the response body to.
+            download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+                requester IP
+            ip_address: IP address of the requester
+            max_size: maximum allowable size in bytes of the file
+            args: Optional dictionary used to create the query string.
+            ignore_backoff: true to ignore the historical backoff data
+                and try the request anyway.
+
+        Returns:
+            Resolves to an (int, dict, bytes) tuple of
+            the file length, a dict of the response headers, and the file json
+
+        Raises:
+            HttpResponseException: If we get an HTTP response code >= 300
+                (except 429).
+            NotRetryingDestination: If we are not yet ready to retry this
+                server.
+            FederationDeniedError: If this destination is not on our
+                federation whitelist
+            RequestSendFailed: If there were problems connecting to the
+                remote, due to e.g. DNS failures, connection timeouts etc.
+            SynapseError: If the requested file exceeds ratelimits or the response from the
+            remote server is not a multipart response
+            AssertionError: if the resolved multipart response's length is None
+        """
+        request = MatrixFederationRequest(
+            method="GET", destination=destination, path=path, query=args
+        )
+
+        # check for a minimum balance of 1MiB in ratelimiter before initiating request
+        send_req, _ = await download_ratelimiter.can_do_action(
+            requester=None, key=ip_address, n_actions=1048576, update=False
+        )
+
+        if not send_req:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+        response = await self._send_request(
+            request,
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
+        )
+
+        headers = dict(response.headers.getAllRawHeaders())
+
+        expected_size = response.length
+        # if we don't get an expected length then use the max length
+        if expected_size == UNKNOWN_LENGTH:
+            expected_size = max_size
+            logger.debug(
+                f"File size unknown, assuming file is max allowable size: {max_size}"
+            )
+
+        read_body, _ = await download_ratelimiter.can_do_action(
+            requester=None,
+            key=ip_address,
+            n_actions=expected_size,
+        )
+        if not read_body:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
+        # this should be a multipart/mixed response with the boundary string in the header
+        try:
+            raw_content_type = headers.get(b"Content-Type")
+            assert raw_content_type is not None
+            content_type = raw_content_type[0].decode("UTF-8")
+            content_type_parts = content_type.split("boundary=")
+            boundary = content_type_parts[1]
+        except Exception:
+            msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present."
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
+
+        try:
+            # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
+            deferred = read_multipart_response(
+                response, output_stream, boundary, expected_size + 1
+            )
+            deferred.addTimeout(self.default_timeout_seconds, self.reactor)
+        except BodyExceededMaxSize:
+            msg = "Requested file is too large > %r bytes" % (expected_size,)
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
+        except defer.TimeoutError as e:
+            logger.warning(
+                "{%s} [%s] Timed out reading response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except ResponseFailed as e:
+            logger.warning(
+                "{%s} [%s] Failed to read response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except Exception as e:
+            logger.warning(
+                "{%s} [%s] Error reading response: %s",
+                request.txn_id,
+                request.destination,
+                e,
+            )
+            raise
+
+        multipart_response = await make_deferred_yieldable(deferred)
+        if not multipart_response.url:
+            assert multipart_response.length is not None
+            length = multipart_response.length
+            headers[b"Content-Type"] = [multipart_response.content_type]
+            headers[b"Content-Disposition"] = [multipart_response.disposition]
+
+        # the response contained a redirect url to download the file from
+        else:
+            str_url = multipart_response.url.decode("utf-8")
+            logger.info(
+                "{%s} [%s] File download redirected, now downloading from: %s",
+                request.txn_id,
+                request.destination,
+                str_url,
+            )
+            length, headers, _, _ = await self._simple_http_client.get_file(
+                str_url, output_stream, expected_size
+            )
+
+        logger.info(
+            "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
+            request.txn_id,
+            request.destination,
+            response.code,
+            response.phrase.decode("ascii", errors="replace"),
+            length,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        return length, headers, multipart_response.json
+
 
 def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 7ad0b7c3cf..1b268ce4d4 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -221,6 +221,7 @@ def add_file_headers(
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+
     if file_size is not None:
         request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
@@ -302,12 +303,37 @@ async def respond_with_multipart_responder(
             )
             return
 
+        if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
+            disposition = "inline"
+        else:
+            disposition = "attachment"
+
+        def _quote(x: str) -> str:
+            return urllib.parse.quote(x.encode("utf-8"))
+
+        if media_info.upload_name:
+            if _can_encode_filename_as_token(media_info.upload_name):
+                disposition = "%s; filename=%s" % (
+                    disposition,
+                    media_info.upload_name,
+                )
+            else:
+                disposition = "%s; filename*=utf-8''%s" % (
+                    disposition,
+                    _quote(media_info.upload_name),
+                )
+
         from synapse.media.media_storage import MultipartFileConsumer
 
         # note that currently the json_object is just {}, this will change when linked media
         # is implemented
         multipart_consumer = MultipartFileConsumer(
-            clock, request, media_info.media_type, {}, media_info.media_length
+            clock,
+            request,
+            media_info.media_type,
+            {},
+            disposition,
+            media_info.media_length,
         )
 
         logger.debug("Responding to media request with responder %s", responder)
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1436329fad..542642b900 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -480,6 +480,7 @@ class MediaRepository:
         name: Optional[str],
         max_timeout_ms: int,
         ip_address: str,
+        use_federation_endpoint: bool,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -492,6 +493,8 @@ class MediaRepository:
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
             ip_address: the IP address of the requester
+            use_federation_endpoint: whether to request the remote media over the new
+                federation `/download` endpoint
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -522,6 +525,7 @@ class MediaRepository:
                 max_timeout_ms,
                 self.download_ratelimiter,
                 ip_address,
+                use_federation_endpoint,
             )
 
         # We deliberately stream the file outside the lock
@@ -569,6 +573,7 @@ class MediaRepository:
                 max_timeout_ms,
                 self.download_ratelimiter,
                 ip_address,
+                False,
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -585,6 +590,7 @@ class MediaRepository:
         max_timeout_ms: int,
         download_ratelimiter: Ratelimiter,
         ip_address: str,
+        use_federation_endpoint: bool,
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -598,6 +604,8 @@ class MediaRepository:
             download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
                 requester IP.
             ip_address: the IP address of the requester
+            use_federation_endpoint: whether to request the remote media over the new federation
+            /download endpoint
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -629,9 +637,23 @@ class MediaRepository:
         # Failed to find the file anywhere, lets download it.
 
         try:
-            media_info = await self._download_remote_file(
-                server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
-            )
+            if not use_federation_endpoint:
+                media_info = await self._download_remote_file(
+                    server_name,
+                    media_id,
+                    max_timeout_ms,
+                    download_ratelimiter,
+                    ip_address,
+                )
+            else:
+                media_info = await self._federation_download_remote_file(
+                    server_name,
+                    media_id,
+                    max_timeout_ms,
+                    download_ratelimiter,
+                    ip_address,
+                )
+
         except SynapseError:
             raise
         except Exception as e:
@@ -775,6 +797,129 @@ class MediaRepository:
             quarantined_by=None,
         )
 
+    async def _federation_download_remote_file(
+        self,
+        server_name: str,
+        media_id: str,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+    ) -> RemoteMedia:
+        """Attempt to download the remote file from the given server name.
+        Uses the given file_id as the local id and downloads the file over the federation
+        v1 download endpoint
+
+        Args:
+            server_name: Originating server
+            media_id: The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP
+            ip_address: the IP address of the requester
+
+        Returns:
+            The media info of the file.
+        """
+
+        file_id = random_string(24)
+
+        file_info = FileInfo(server_name=server_name, file_id=file_id)
+
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
+            try:
+                res = await self.client.federation_download_media(
+                    server_name,
+                    media_id,
+                    output_stream=f,
+                    max_size=self.max_upload_size,
+                    max_timeout_ms=max_timeout_ms,
+                    download_ratelimiter=download_ratelimiter,
+                    ip_address=ip_address,
+                )
+                # if we had to fall back to the _matrix/media endpoint it will only return
+                # the headers and length, check the length of the tuple before unpacking
+                if len(res) == 3:
+                    length, headers, json = res
+                else:
+                    length, headers = res
+            except RequestSendFailed as e:
+                logger.warning(
+                    "Request failed fetching remote media %s/%s: %r",
+                    server_name,
+                    media_id,
+                    e,
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except HttpResponseException as e:
+                logger.warning(
+                    "HTTP error fetching remote media %s/%s: %s",
+                    server_name,
+                    media_id,
+                    e.response,
+                )
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise e.to_synapse_error()
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.warning(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise
+            except NotRetryingDestination:
+                logger.warning("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            if b"Content-Type" in headers:
+                media_type = headers[b"Content-Type"][0].decode("ascii")
+            else:
+                media_type = "application/octet-stream"
+            upload_name = get_filename_from_headers(headers)
+            time_now_ms = self.clock.time_msec()
+
+            # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a violation constraint
+            # exception. If it does we'll delete the newly downloaded file from
+            # disk (as we're in the ctx manager).
+            #
+            # However: we've already called `finish()` so we may have also
+            # written to the storage providers. This is preferable to the
+            # alternative where we call `finish()` *after* this, where we could
+            # end up having an entry in the DB but fail to write the files to
+            # the storage providers.
+            await self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=time_now_ms,
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
+
+        logger.debug("Stored remote media in file %r", fname)
+
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
+
     def _get_thumbnail_requirements(
         self, media_type: str
     ) -> Tuple[ThumbnailRequirement, ...]:
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 1be2c9b5f5..2a106bb0eb 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -401,13 +401,14 @@ class MultipartFileConsumer:
         wrapped_consumer: interfaces.IConsumer,
         file_content_type: str,
         json_object: JsonDict,
-        content_length: Optional[int] = None,
+        disposition: str,
+        content_length: Optional[int],
     ) -> None:
         self.clock = clock
         self.wrapped_consumer = wrapped_consumer
         self.json_field = json_object
         self.json_field_written = False
-        self.content_type_written = False
+        self.file_headers_written = False
         self.file_content_type = file_content_type
         self.boundary = uuid4().hex.encode("ascii")
 
@@ -420,6 +421,7 @@ class MultipartFileConsumer:
         self.paused = False
 
         self.length = content_length
+        self.disposition = disposition
 
     ### IConsumer APIs ###
 
@@ -488,11 +490,13 @@ class MultipartFileConsumer:
             self.json_field_written = True
 
         # if we haven't written the content type yet, do so
-        if not self.content_type_written:
+        if not self.file_headers_written:
             type = self.file_content_type.encode("utf-8")
             content_type = Header(b"Content-Type", type)
-            self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
-            self.content_type_written = True
+            self.wrapped_consumer.write(bytes(content_type) + CRLF)
+            disp_header = Header(b"Content-Disposition", self.disposition)
+            self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF)
+            self.file_headers_written = True
 
         self.wrapped_consumer.write(data)
 
@@ -506,7 +510,6 @@ class MultipartFileConsumer:
         producing data for good.
         """
         assert self.producer is not None
-
         self.paused = True
         self.producer.stopProducing()
 
@@ -518,7 +521,6 @@ class MultipartFileConsumer:
         the time being, and to stop until C{resumeProducing()} is called.
         """
         assert self.producer is not None
-
         self.paused = True
 
         if self.streaming:
@@ -549,7 +551,7 @@ class MultipartFileConsumer:
         """
         if not self.length:
             return None
-        # calculate length of json field and content-type header
+        # calculate length of json field and content-type, disposition headers
         json_field = json.dumps(self.json_field)
         json_bytes = json_field.encode("utf-8")
         json_length = len(json_bytes)
@@ -558,9 +560,13 @@ class MultipartFileConsumer:
         content_type = Header(b"Content-Type", type)
         type_length = len(bytes(content_type))
 
-        # 154 is the length of the elements that aren't variable, ie
+        disp = self.disposition.encode("utf-8")
+        disp_header = Header(b"Content-Disposition", disp)
+        disp_length = len(bytes(disp_header))
+
+        # 156 is the length of the elements that aren't variable, ie
         # CRLFs and boundary strings, etc
-        self.length += json_length + type_length + 154
+        self.length += json_length + type_length + disp_length + 156
 
         return self.length
 
@@ -569,7 +575,6 @@ class MultipartFileConsumer:
     async def _resumeProducingRepeatedly(self) -> None:
         assert self.producer is not None
         assert not self.streaming
-
         producer = cast("interfaces.IPullProducer", self.producer)
 
         self.paused = False
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 0024ccf708..c94d454a28 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -145,6 +145,10 @@ class ClientRestResource(JsonResource):
         password_policy.register_servlets(hs, client_resource)
         knock.register_servlets(hs, client_resource)
         appservice_ping.register_servlets(hs, client_resource)
+        if hs.config.server.enable_media_repo:
+            from synapse.rest.client import media
+
+            media.register_servlets(hs, client_resource)
 
         # moving to /_synapse/admin
         if is_main_process:
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index 0c089163c1..c0ae5dd66f 100644
--- a/synapse/rest/client/media.py
+++ b/synapse/rest/client/media.py
@@ -22,6 +22,7 @@
 
 import logging
 import re
+from typing import Optional
 
 from synapse.http.server import (
     HttpServer,
@@ -194,14 +195,76 @@ class UnstableThumbnailResource(RestServlet):
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
 
+class DownloadResource(RestServlet):
+    PATTERNS = [
+        re.compile(
+            "/_matrix/client/v1/media/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
+        )
+    ]
+
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
+        super().__init__()
+        self.media_repo = media_repo
+        self._is_mine_server_name = hs.is_mine_server_name
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        file_name: Optional[str] = None,
+    ) -> None:
+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+
+        await self.auth.get_user_by_req(request)
+
+        set_cors_headers(request)
+        set_corp_headers(request)
+        request.setHeader(
+            b"Content-Security-Policy",
+            b"sandbox;"
+            b" default-src 'none';"
+            b" script-src 'none';"
+            b" plugin-types application/pdf;"
+            b" style-src 'unsafe-inline';"
+            b" media-src 'self';"
+            b" object-src 'self';",
+        )
+        # Limited non-standard form of CSP for IE11
+        request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
+        request.setHeader(b"Referrer-Policy", b"no-referrer")
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+        )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
+        if self._is_mine_server_name(server_name):
+            await self.media_repo.get_local_media(
+                request, media_id, file_name, max_timeout_ms
+            )
+        else:
+            ip_address = request.getClientAddress().host
+            await self.media_repo.get_remote_media(
+                request,
+                server_name,
+                media_id,
+                file_name,
+                max_timeout_ms,
+                ip_address,
+                True,
+            )
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
-    if hs.config.experimental.msc3916_authenticated_media_enabled:
-        media_repo = hs.get_media_repository()
-        if hs.config.media.url_preview_enabled:
-            UnstablePreviewURLServlet(
-                hs, media_repo, media_repo.media_storage
-            ).register(http_server)
-        UnstableMediaConfigResource(hs).register(http_server)
-        UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+    media_repo = hs.get_media_repository()
+    if hs.config.media.url_preview_enabled:
+        UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register(
             http_server
         )
+    UnstableMediaConfigResource(hs).register(http_server)
+    UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+        http_server
+    )
+    DownloadResource(hs, media_repo).register(http_server)
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 1628d58926..c32c626905 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -105,4 +105,5 @@ class DownloadResource(RestServlet):
                 file_name,
                 max_timeout_ms,
                 ip_address,
+                False,
             )
diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py
index 2c396adbe3..142f73cfdb 100644
--- a/tests/federation/test_federation_media.py
+++ b/tests/federation/test_federation_media.py
@@ -36,10 +36,9 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import SMALL_PNG
-from tests.unittest import override_config
 
 
-class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
+class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
@@ -65,9 +64,6 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
         )
         self.media_repo = hs.get_media_repository()
 
-    @override_config(
-        {"experimental_features": {"msc3916_authenticated_media_enabled": True}}
-    )
     def test_file_download(self) -> None:
         content = io.BytesIO(b"file_to_stream")
         content_uri = self.get_success(
@@ -82,7 +78,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
         # test with a text file
         channel = self.make_signed_federation_request(
             "GET",
-            f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
+            f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
         )
         self.pump()
         self.assertEqual(200, channel.code)
@@ -106,7 +102,8 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
 
         # check that the text file and expected value exist
         found_file = any(
-            "\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field
+            "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
+            in field
             for field in stripped
         )
         self.assertTrue(found_file)
@@ -124,7 +121,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
         # test with an image file
         channel = self.make_signed_federation_request(
             "GET",
-            f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
+            f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
         )
         self.pump()
         self.assertEqual(200, channel.code)
@@ -149,25 +146,3 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
         # check that the png file exists and matches what was uploaded
         found_file = any(SMALL_PNG in field for field in stripped_bytes)
         self.assertTrue(found_file)
-
-    @override_config(
-        {"experimental_features": {"msc3916_authenticated_media_enabled": False}}
-    )
-    def test_disable_config(self) -> None:
-        content = io.BytesIO(b"file_to_stream")
-        content_uri = self.get_success(
-            self.media_repo.create_content(
-                "text/plain",
-                "test_upload",
-                content,
-                46,
-                UserID.from_string("@user_id:whatever.org"),
-            )
-        )
-        channel = self.make_signed_federation_request(
-            "GET",
-            f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
-        )
-        self.pump()
-        self.assertEqual(404, channel.code)
-        self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index a98091d711..721917f957 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -37,18 +37,155 @@ from synapse.http.client import (
     BlocklistingAgentWrapper,
     BlocklistingReactorWrapper,
     BodyExceededMaxSize,
+    MultipartResponse,
     _DiscardBodyWithMaxSizeProtocol,
+    _MultipartParserProtocol,
     read_body_with_max_size,
+    read_multipart_response,
 )
 
 from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
 
 
+class ReadMultipartResponseTests(TestCase):
+    data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+    data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+    redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+    def _build_multipart_response(
+        self, response_length: Union[int, str], max_length: int
+    ) -> Tuple[
+        BytesIO,
+        "Deferred[MultipartResponse]",
+        _MultipartParserProtocol,
+    ]:
+        """Start reading the body, returns the response, result and proto"""
+        response = Mock(length=response_length)
+        result = BytesIO()
+        boundary = "6067d4698f8d40a0a794ea7d7379d53a"
+        deferred = read_multipart_response(response, result, boundary, max_length)
+
+        # Fish the protocol out of the response.
+        protocol = response.deliverBody.call_args[0][0]
+        protocol.transport = Mock()
+
+        return result, deferred, protocol
+
+    def _assert_error(
+        self,
+        deferred: "Deferred[MultipartResponse]",
+        protocol: _MultipartParserProtocol,
+    ) -> None:
+        """Ensure that the expected error is received."""
+        assert isinstance(deferred.result, Failure)
+        self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+        assert protocol.transport is not None
+        # type-ignore: presumably abortConnection has been replaced with a Mock.
+        protocol.transport.abortConnection.assert_called_once()  # type: ignore[attr-defined]
+
+    def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
+        """Ensure that the error in the Deferred is handled gracefully."""
+        called = [False]
+
+        def errback(f: Failure) -> None:
+            called[0] = True
+
+        deferred.addErrback(errback)
+        self.assertTrue(called[0])
+
+    def test_parse_file(self) -> None:
+        """
+        Check that a multipart response containing a file is properly parsed
+        into the json/file parts, and the json and file are properly captured
+        """
+        result, deferred, protocol = self._build_multipart_response(249, 250)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        protocol.dataReceived(self.data2)
+        # Close the connection.
+        protocol.connectionLost(Failure(ResponseDone()))
+
+        multipart_response: MultipartResponse = deferred.result  # type: ignore[assignment]
+
+        self.assertEqual(multipart_response.json, b"{}")
+        self.assertEqual(result.getvalue(), b"file_to_stream")
+        self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+        self.assertEqual(multipart_response.content_type, b"text/plain")
+        self.assertEqual(
+            multipart_response.disposition, b"inline; filename=test_upload"
+        )
+
+    def test_parse_redirect(self) -> None:
+        """
+        check that a multipart response containing a redirect is properly parsed and redirect url is
+        returned
+        """
+        result, deferred, protocol = self._build_multipart_response(249, 250)
+
+        # Start sending data.
+        protocol.dataReceived(self.redirect_data)
+        # Close the connection.
+        protocol.connectionLost(Failure(ResponseDone()))
+
+        multipart_response: MultipartResponse = deferred.result  # type: ignore[assignment]
+
+        self.assertEqual(multipart_response.json, b"{}")
+        self.assertEqual(result.getvalue(), b"")
+        self.assertEqual(
+            multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
+        )
+
+    def test_too_large(self) -> None:
+        """A response which is too large raises an exception."""
+        result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+
+        self.assertEqual(result.getvalue(), b"file_")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+    def test_additional_data(self) -> None:
+        """A connection can receive data after being closed."""
+        result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        self._assert_error(deferred, protocol)
+
+        # More data might have come in.
+        protocol.dataReceived(self.data2)
+
+        self.assertEqual(result.getvalue(), b"file_")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+    def test_content_length(self) -> None:
+        """The body shouldn't be read (at all) if the Content-Length header is too large."""
+        result, deferred, protocol = self._build_multipart_response(250, 1)
+
+        # Deferred shouldn't be called yet.
+        self.assertFalse(deferred.called)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+        # The data is never consumed.
+        self.assertEqual(result.getvalue(), b"")
+
+
 class ReadBodyWithMaxSizeTests(TestCase):
-    def _build_response(
-        self, length: Union[int, str] = UNKNOWN_LENGTH
-    ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
+    def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
+        BytesIO,
+        "Deferred[int]",
+        _DiscardBodyWithMaxSizeProtocol,
+    ]:
         """Start reading the body, returns the response, result and proto"""
         response = Mock(length=length)
         result = BytesIO()
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 46d20ce775..024086b775 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -129,7 +129,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
 
 
 @attr.s(auto_attribs=True, slots=True, frozen=True)
-class _TestImage:
+class TestImage:
     """An image for testing thumbnailing with the expected results
 
     Attributes:
@@ -158,7 +158,7 @@ class _TestImage:
     is_inline: bool = True
 
 
-small_png = _TestImage(
+small_png = TestImage(
     SMALL_PNG,
     b"image/png",
     b".png",
@@ -175,7 +175,7 @@ small_png = _TestImage(
     ),
 )
 
-small_png_with_transparency = _TestImage(
+small_png_with_transparency = TestImage(
     unhexlify(
         b"89504e470d0a1a0a0000000d49484452000000010000000101000"
         b"00000376ef9240000000274524e5300010194fdae0000000a4944"
@@ -188,7 +188,7 @@ small_png_with_transparency = _TestImage(
     # different versions of Pillow.
 )
 
-small_lossless_webp = _TestImage(
+small_lossless_webp = TestImage(
     unhexlify(
         b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
     ),
@@ -196,7 +196,7 @@ small_lossless_webp = _TestImage(
     b".webp",
 )
 
-empty_file = _TestImage(
+empty_file = TestImage(
     b"",
     b"image/gif",
     b".gif",
@@ -204,7 +204,7 @@ empty_file = _TestImage(
     unable_to_thumbnail=True,
 )
 
-SVG = _TestImage(
+SVG = TestImage(
     b"""<?xml version="1.0"?>
 <!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
   "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -236,7 +236,7 @@ urls = [
 @parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
 class MediaRepoTests(unittest.HomeserverTestCase):
     servlets = [media.register_servlets]
-    test_image: ClassVar[_TestImage]
+    test_image: ClassVar[TestImage]
     hijack_auth = True
     user_id = "@test:user"
     url: ClassVar[str]
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 4927e45446..6fc4600c41 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.rest import admin
-from synapse.rest.client import login
+from synapse.rest.client import login, media
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
+class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks running multiple media repos work correctly using autheticated media paths"""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        media.register_servlets,
+    ]
+
+    file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.user_id = self.register_user("user", "pass")
+        self.access_token = self.login("user", "pass")
+
+        self.reactor.lookups["example.com"] = "1.2.3.4"
+
+    def default_config(self) -> dict:
+        conf = super().default_config()
+        conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+        return conf
+
+    def make_worker_hs(
+        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
+    ) -> HomeServer:
+        worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
+        # Force the media paths onto the replication resource.
+        worker_hs.get_media_repository_resource().register_servlets(
+            self._hs_to_site[worker_hs].resource, worker_hs
+        )
+        return worker_hs
+
+    def _get_media_req(
+        self, hs: HomeServer, target: str, media_id: str
+    ) -> Tuple[FakeChannel, Request]:
+        """Request some remote media from the given HS by calling the download
+        API.
+
+        This then triggers an outbound request from the HS to the target.
+
+        Returns:
+            The channel for the *client* request and the *outbound* request for
+            the media which the caller should respond to.
+        """
+        channel = make_request(
+            self.reactor,
+            self._hs_to_site[hs],
+            "GET",
+            f"/_matrix/client/v1/media/download/{target}/{media_id}",
+            shorthand=False,
+            access_token=self.access_token,
+            await_result=False,
+        )
+        self.pump()
+
+        clients = self.reactor.tcpClients
+        self.assertGreaterEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+        # build the test server
+        server_factory = Factory.forProtocol(HTTPChannel)
+        # Request.finish expects the factory to have a 'log' method.
+        server_factory.log = _log_request
+
+        server_tls_protocol = wrap_server_factory_for_tls(
+            server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+        ).buildProtocol(None)
+
+        # now, tell the client protocol factory to build the client protocol (it will be a
+        # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+        # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+        # a FakeTransport.
+        #
+        # Normally this would be done by the TCP socket code in Twisted, but we are
+        # stubbing that out here.
+        client_protocol = client_factory.buildProtocol(None)
+        client_protocol.makeConnection(
+            FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+        )
+
+        # tell the server tls protocol to send its stuff back to the client, too
+        server_tls_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+        )
+
+        # fish the test server back out of the server-side TLS protocol.
+        http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
+
+        # give the reactor a pump to get the TLS juices flowing.
+        self.reactor.pump((0.1,))
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(
+            request.path,
+            f"/_matrix/federation/v1/media/download/{media_id}".encode(),
+        )
+        self.assertEqual(
+            request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+        )
+
+        return channel, request
+
+    def test_basic(self) -> None:
+        """Test basic fetching of remote media from a single worker."""
+        hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+        channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+        request.setResponseCode(200)
+        request.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request.write(self.file_data)
+        request.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.result["body"], b"file_to_stream")
+
+    def test_download_simple_file_race(self) -> None:
+        """Test that fetching remote media from two different processes at the
+        same time works.
+        """
+        hs1 = self.make_worker_hs("synapse.app.generic_worker")
+        hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+        start_count = self._count_remote_media()
+
+        # Make two requests without responding to the outbound media requests.
+        channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+        channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+        # Respond to the first outbound media request and check that the client
+        # request is successful
+        request1.setResponseCode(200)
+        request1.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request1.write(self.file_data)
+        request1.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel1.code, 200, channel1.result["body"])
+        self.assertEqual(channel1.result["body"], b"file_to_stream")
+
+        # Now respond to the second with the same content.
+        request2.setResponseCode(200)
+        request2.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request2.write(self.file_data)
+        request2.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel2.code, 200, channel2.result["body"])
+        self.assertEqual(channel2.result["body"], b"file_to_stream")
+
+        # We expect only one new file to have been persisted.
+        self.assertEqual(start_count + 1, self._count_remote_media())
+
+    def test_download_image_race(self) -> None:
+        """Test that fetching remote *images* from two different processes at
+        the same time works.
+
+        This checks that races generating thumbnails are handled correctly.
+        """
+        hs1 = self.make_worker_hs("synapse.app.generic_worker")
+        hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+        start_count = self._count_remote_thumbnails()
+
+        channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+        channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+        request1.setResponseCode(200)
+        request1.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
+        request1.write(img_data)
+        request1.write(SMALL_PNG)
+        request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
+        request1.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel1.code, 200, channel1.result["body"])
+        self.assertEqual(channel1.result["body"], SMALL_PNG)
+
+        request2.setResponseCode(200)
+        request2.responseHeaders.setRawHeaders(
+            b"Content-Type",
+            ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+        )
+        request2.write(img_data)
+        request2.write(SMALL_PNG)
+        request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
+        request2.finish()
+
+        self.pump(0.1)
+
+        self.assertEqual(channel2.code, 200, channel2.result["body"])
+        self.assertEqual(channel2.result["body"], SMALL_PNG)
+
+        # We expect only three new thumbnails to have been persisted.
+        self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+    def _count_remote_media(self) -> int:
+        """Count the number of files in our remote media directory."""
+        path = os.path.join(
+            self.hs.get_media_repository().primary_base_path, "remote_content"
+        )
+        return sum(len(files) for _, _, files in os.walk(path))
+
+    def _count_remote_thumbnails(self) -> int:
+        """Count the number of files in our remote thumbnails directory."""
+        path = os.path.join(
+            self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+        )
+        return sum(len(files) for _, _, files in os.walk(path))
+
+
 def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index be4a289ec1..6b5af2dbb6 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -19,31 +19,54 @@
 #
 #
 import base64
+import io
 import json
 import os
 import re
-from typing import Any, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
+from unittest.mock import MagicMock, Mock, patch
+from urllib import parse
 from urllib.parse import quote, urlencode
 
+from parameterized import parameterized_class
+
+from twisted.internet import defer
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
 from twisted.web.resource import Resource
 
+from synapse.api.errors import HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.http.client import MultipartResponse
+from synapse.http.types import QueryParams
+from synapse.logging.context import make_deferred_yieldable
 from synapse.media._base import FileInfo
 from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
 from synapse.rest import admin
 from synapse.rest.client import login, media
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from tests import unittest
-from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.media.test_media_storage import (
+    SVG,
+    TestImage,
+    empty_file,
+    small_lossless_webp,
+    small_png,
+    small_png_with_transparency,
+)
+from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
 from tests.test_utils import SMALL_PNG
 from tests.unittest import override_config
 
@@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
         )
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        media.register_servlets,
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+        config["media_store_path"] = self.media_store_path
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        config["media_storage_providers"] = [provider_config]
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.repo = hs.get_media_repository()
+        self.client = hs.get_federation_http_client()
+        self.store = hs.get_datastores().main
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+    # mock actually reading file body
+    def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+        return d
+
+    def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+        return d
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_download_ratelimit_default(self) -> None:
+        """
+        Test remote media download ratelimiting against default configuration - 500MB bucket
+        and 87kb/second drain rate
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abc",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+        # next 15 should go through
+        for i in range(15):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+                shorthand=False,
+                access_token=self.tok,
+            )
+            assert channel2.code == 200
+
+        # 17th will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel3.code == 429
+
+        # however, a request from a different IP will go through
+        channel4 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcde",
+            shorthand=False,
+            client_ip="187.233.230.159",
+            access_token=self.tok,
+        )
+        assert channel4.code == 200
+
+        # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
+        # 30MiB download is authorized - The last download was blocked at 503,316,480.
+        # The next download will be authorized when bucket hits 492,830,720
+        # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
+        # needs to drain before another download will be authorized, that will take ~=
+        # 2 minutes (10,485,760/89,088/60)
+        self.reactor.pump([2.0 * 60.0])
+
+        # enough has drained and next request goes through
+        channel5 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcdef",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel5.code == 200
+
+    @override_config(
+        {
+            "remote_media_download_per_second": "50M",
+            "remote_media_download_burst_count": "50M",
+        }
+    )
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_50MiB,
+    )
+    def test_download_rate_limit_config(self) -> None:
+        """
+        Test that download rate limit config options are correctly picked up and applied
+        """
+
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 52428800
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abc",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+        # immediate second request should fail
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 429
+
+        # advance half a second
+        self.reactor.pump([0.5])
+
+        # request still fails
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcde",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 429
+
+        # advance another half second
+        self.reactor.pump([0.5])
+
+        # enough has drained from bucket and request is successful
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcdef",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_download_ratelimit_max_size_sub(self) -> None:
+        """
+        Test that if no content-length is provided, the default max size is applied instead
+        """
+
+        # mock out actually sending the request
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = UNKNOWN_LENGTH
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # ten requests should go through using the max size (500MB/50MB)
+        for i in range(10):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+                shorthand=False,
+                access_token=self.tok,
+            )
+            assert channel2.code == 200
+
+        # eleventh will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel3.code == 429
+
+    def test_file_download(self) -> None:
+        content = io.BytesIO(b"file_to_stream")
+        content_uri = self.get_success(
+            self.repo.create_content(
+                "text/plain",
+                "test_upload",
+                content,
+                46,
+                UserID.from_string("@user_id:whatever.org"),
+            )
+        )
+        # test with a text file
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.pump()
+        self.assertEqual(200, channel.code)
+
+
+test_images = [
+    small_png,
+    small_png_with_transparency,
+    small_lossless_webp,
+    empty_file,
+    SVG,
+]
+input_values = [(x,) for x in test_images]
+
+
+@parameterized_class(("test_image",), input_values)
+class DownloadTestCase(unittest.HomeserverTestCase):
+    test_image: ClassVar[TestImage]
+    servlets = [
+        media.register_servlets,
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.fetches: List[
+            Tuple[
+                "Deferred[Any]",
+                str,
+                str,
+                Optional[QueryParams],
+            ]
+        ] = []
+
+        def federation_get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
+            ignore_backoff: bool = False,
+            follow_redirects: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
+            """A mock for MatrixFederationHttpClient.federation_get_file."""
+
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            def write_err(f: Failure) -> Failure:
+                f.trap(HttpResponseException)
+                output_stream.write(f.value.response)
+                return f
+
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
+                Deferred()
+            )
+            self.fetches.append((d, destination, path, args))
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallbacks(write_to, write_err)
+            return make_deferred_yieldable(d_after_callback)
+
+        def get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
+            ignore_backoff: bool = False,
+            follow_redirects: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+            """A mock for MatrixFederationHttpClient.get_file."""
+
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]]]:
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            def write_err(f: Failure) -> Failure:
+                f.trap(HttpResponseException)
+                output_stream.write(f.value.response)
+                return f
+
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
+            self.fetches.append((d, destination, path, args))
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallbacks(write_to, write_err)
+            return make_deferred_yieldable(d_after_callback)
+
+        # Mock out the homeserver's MatrixFederationHttpClient
+        client = Mock()
+        client.federation_get_file = federation_get_file
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+
+        config = self.default_config()
+        config["media_store_path"] = self.media_store_path
+        config["max_image_pixels"] = 2000000
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+        config["media_storage_providers"] = [provider_config]
+        config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
+
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
+
+        return hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.media_repo = hs.get_media_repository()
+
+        self.remote = "example.com"
+        self.media_id = "12345"
+
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+    def _req(
+        self, content_disposition: Optional[bytes], include_content_type: bool = True
+    ) -> FakeChannel:
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        # We've made one fetch, to example.com, using the federation media URL
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
+        )
+        self.assertEqual(
+            self.fetches[0][3],
+            {"timeout_ms": "20000"},
+        )
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.test_image.data))],
+        }
+
+        if include_content_type:
+            headers[b"Content-Type"] = [self.test_image.content_type]
+
+        if content_disposition:
+            headers[b"Content-Disposition"] = [content_disposition]
+
+        self.fetches[0][0].callback(
+            (self.test_image.data, (len(self.test_image.data), headers, b"{}"))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        return channel
+
+    def test_handle_missing_content_type(self) -> None:
+        channel = self._req(
+            b"attachment; filename=out" + self.test_image.extension,
+            include_content_type=False,
+        )
+        headers = channel.headers
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
+        )
+
+    def test_disposition_filename_ascii(self) -> None:
+        """
+        If the filename is filename=<ascii> then Synapse will decode it as an
+        ASCII string, and use filename= in the response.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename=out"
+                + self.test_image.extension
+            ],
+        )
+
+    def test_disposition_filenamestar_utf8escaped(self) -> None:
+        """
+        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+        correctly decode it as the UTF-8 string, and use filename* in the
+        response.
+        """
+        filename = parse.quote("\u2603".encode()).encode("ascii")
+        channel = self._req(
+            b"attachment; filename*=utf-8''" + filename + self.test_image.extension
+        )
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename*=utf-8''"
+                + filename
+                + self.test_image.extension
+            ],
+        )
+
+    def test_disposition_none(self) -> None:
+        """
+        If there is no filename, Content-Disposition should only
+        be a disposition type.
+        """
+        channel = self._req(None)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline" if self.test_image.is_inline else b"attachment"],
+        )
+
+    def test_x_robots_tag_header(self) -> None:
+        """
+        Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+        to not index, archive, or follow links in media.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"X-Robots-Tag"),
+            [b"noindex, nofollow, noarchive, noimageindex"],
+        )
+
+    def test_cross_origin_resource_policy_header(self) -> None:
+        """
+        Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
+        allowing web clients to embed media from the downloads API.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+
+        self.assertEqual(
+            headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+            [b"cross-origin"],
+        )
+
+    def test_unknown_federation_endpoint(self) -> None:
+        """
+        Test that if the downloadd request to remote federation endpoint returns a 404
+        we fall back to the _matrix/media endpoint
+        """
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        # We've made one fetch, to example.com, using the media URL, and asking
+        # the other server not to do a remote fetch
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
+        )
+
+        # The result which says the endpoint is unknown.
+        unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
+        self.fetches[0][0].errback(
+            HttpResponseException(404, "NOT FOUND", unknown_endpoint)
+        )
+
+        self.pump()
+
+        # There should now be another request to the _matrix/media/v3/download URL.
+        self.assertEqual(len(self.fetches), 2)
+        self.assertEqual(self.fetches[1][1], "example.com")
+        self.assertEqual(
+            self.fetches[1][2],
+            f"/_matrix/media/v3/download/example.com/{self.media_id}",
+        )
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.test_image.data))],
+        }
+
+        self.fetches[1][0].callback(
+            (self.test_image.data, (len(self.test_image.data), headers))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)