summary refs log tree commit diff
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-06-05 05:43:36 -0700
committerGitHub <noreply@github.com>2024-06-05 13:43:36 +0100
commitfcbc79bb87d08147e86dafa0fee5a9aec4d3fc23 (patch)
tree5d5efc029b3410f96ce9aa5dc49b55e7ef4755c5
parentHandle hyphens in user dir search porperly (#17254) (diff)
downloadsynapse-fcbc79bb87d08147e86dafa0fee5a9aec4d3fc23.tar.xz
Ratelimiting of remote media downloads (#17256)
-rw-r--r--changelog.d/17256.feature1
-rw-r--r--docs/usage/configuration/config_documentation.md18
-rw-r--r--synapse/config/ratelimiting.py10
-rw-r--r--synapse/federation/federation_client.py7
-rw-r--r--synapse/federation/transport/client.py9
-rw-r--r--synapse/http/matrixfederationclient.py55
-rw-r--r--synapse/media/media_repository.py43
-rw-r--r--synapse/media/thumbnailer.py6
-rw-r--r--synapse/rest/client/media.py2
-rw-r--r--synapse/rest/media/download_resource.py8
-rw-r--r--synapse/rest/media/thumbnail_resource.py2
-rw-r--r--tests/media/test_media_storage.py225
12 files changed, 372 insertions, 14 deletions
diff --git a/changelog.d/17256.feature b/changelog.d/17256.feature
new file mode 100644
index 0000000000..6ec4cb7a31
--- /dev/null
+++ b/changelog.d/17256.feature
@@ -0,0 +1 @@
+ Improve ratelimiting in Synapse (#17256).
\ No newline at end of file
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 2c917d1f8e..d23f8c4c4f 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -1946,6 +1946,24 @@ Example configuration:
 max_image_pixels: 35M
 ```
 ---
+### `remote_media_download_burst_count`
+
+Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester when requesting remote media downloads. This configuration option sets the size of the bucket against which the size in bytes of downloads are penalized - if the bucket is full, ie a given number of bytes have already been downloaded, further downloads will be denied until the bucket drains.  Defaults to 500MiB. See also `remote_media_download_per_second` which determines the rate at which the "bucket" is emptied and thus has available space to authorize new requests.  
+
+Example configuration:
+```yaml
+remote_media_download_burst_count: 200M
+```
+---
+### `remote_media_download_per_second`
+
+Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this configuration option determines the rate at which the "bucket" (see above) leaks in bytes per second. As requests are made to download remote media, the size of those requests in bytes is added to the bucket, and once the bucket has reached it's capacity, no more requests will be allowed until a number of bytes has "drained" from the bucket. This setting determines the rate at which bytes drain from the bucket, with the practical effect that the larger the number, the faster the bucket leaks, allowing for more bytes downloaded over a shorter period of time. Defaults to 87KiB per second. See also `remote_media_download_burst_count`.
+
+Example configuration:
+```yaml
+remote_media_download_per_second: 40K
+```
+---
 ### `prevent_media_downloads_from`
 
 A list of domains to never download media from. Media from these
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index d2cb4576df..3fa33f5373 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -218,3 +218,13 @@ class RatelimitConfig(Config):
             "rc_media_create",
             defaults={"per_second": 10, "burst_count": 50},
         )
+
+        self.remote_media_downloads = RatelimitSettings(
+            key="rc_remote_media_downloads",
+            per_second=self.parse_size(
+                config.get("remote_media_download_per_second", "87K")
+            ),
+            burst_count=self.parse_size(
+                config.get("remote_media_download_burst_count", "500M")
+            ),
+        )
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e613eb87a6..f0f5a37a57 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,6 +56,7 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         try:
             return await self.transport_layer.download_media_v3(
@@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
                 output_stream=output_stream,
                 max_size=max_size,
                 max_timeout_ms=max_timeout_ms,
+                download_ratelimiter=download_ratelimiter,
+                ip_address=ip_address,
             )
         except HttpResponseException as e:
             # If an error is received that is due to an unrecognised endpoint,
@@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
             output_stream=output_stream,
             max_size=max_size,
             max_timeout_ms=max_timeout_ms,
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index de408f7f8d..af1336fe5f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -43,6 +43,7 @@ import ijson
 
 from synapse.api.constants import Direction, Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import RoomVersion
 from synapse.api.urls import (
     FEDERATION_UNSTABLE_PREFIX,
@@ -819,6 +820,8 @@ class TransportLayerClient:
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/r0/download/{destination}/{media_id}"
 
@@ -834,6 +837,8 @@ class TransportLayerClient:
                 "allow_remote": "false",
                 "timeout_ms": str(max_timeout_ms),
             },
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
     async def download_media_v3(
@@ -843,6 +848,8 @@ class TransportLayerClient:
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/v3/download/{destination}/{media_id}"
 
@@ -862,6 +869,8 @@ class TransportLayerClient:
                 "allow_redirect": "true",
             },
             follow_redirects=True,
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c73a589e6c..104b803b0f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
 
 import synapse.metrics
 import synapse.util.retryutils
@@ -68,6 +68,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
@@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
         destination: str,
         path: str,
         output_stream: BinaryIO,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+        max_size: int,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
-        max_size: Optional[int] = None,
         ignore_backoff: bool = False,
         follow_redirects: bool = False,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
@@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
             destination: The remote server to send the HTTP request to.
             path: The HTTP path to GET.
             output_stream: File to write the response body to.
+            download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+                requester IP
+            ip_address: IP address of the requester
+            max_size: maximum allowable size in bytes of the file
             args: Optional dictionary used to create the query string.
             ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
@@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
                 federation whitelist
             RequestSendFailed: If there were problems connecting to the
                 remote, due to e.g. DNS failures, connection timeouts etc.
+            SynapseError: If the requested file exceeds ratelimits
         """
         request = MatrixFederationRequest(
             method="GET", destination=destination, path=path, query=args
         )
 
+        # check for a minimum balance of 1MiB in ratelimiter before initiating request
+        send_req, _ = await download_ratelimiter.can_do_action(
+            requester=None, key=ip_address, n_actions=1048576, update=False
+        )
+
+        if not send_req:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
         response = await self._send_request(
             request,
             retry_on_dns_fail=retry_on_dns_fail,
@@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
 
         headers = dict(response.headers.getAllRawHeaders())
 
+        expected_size = response.length
+        # if we don't get an expected length then use the max length
+        if expected_size == UNKNOWN_LENGTH:
+            expected_size = max_size
+            logger.debug(
+                f"File size unknown, assuming file is max allowable size: {max_size}"
+            )
+
+        read_body, _ = await download_ratelimiter.can_do_action(
+            requester=None,
+            key=ip_address,
+            n_actions=expected_size,
+        )
+        if not read_body:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
         try:
-            d = read_body_with_max_size(response, output_stream, max_size)
+            # add a byte of headroom to max size as function errs at >=
+            d = read_body_with_max_size(response, output_stream, expected_size + 1)
             d.addTimeout(self.default_timeout_seconds, self.reactor)
             length = await make_deferred_yieldable(d)
         except BodyExceededMaxSize:
-            msg = "Requested file is too large > %r bytes" % (max_size,)
+            msg = "Requested file is too large > %r bytes" % (expected_size,)
             logger.warning(
                 "{%s} [%s] %s",
                 request.txn_id,
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 9c29e09653..6ed56099ca 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
     SynapseError,
     cs_error,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.repository import ThumbnailRequirement
 from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
         )
         self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
 
+        self.download_ratelimiter = Ratelimiter(
+            store=hs.get_storage_controllers().main,
+            clock=hs.get_clock(),
+            cfg=hs.config.ratelimiting.remote_media_downloads,
+        )
+
         # List of StorageProviders where we should search for media and
         # potentially upload to.
         storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
         media_id: str,
         name: Optional[str],
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -475,6 +483,7 @@ class MediaRepository:
                 the filename in the Content-Disposition header of the response.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: the IP address of the requester
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
             respond_404(request)
 
     async def get_remote_media_info(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
     ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: IP address of the requester
 
         Returns:
             The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self,
+        server_name: str,
+        media_id: str,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
                 remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP.
+            ip_address: the IP address of the requester
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name, media_id, max_timeout_ms
+                server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
             )
         except SynapseError:
             raise
@@ -630,6 +656,8 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
                 locally generated.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP
+            ip_address: the IP address of the requester
 
         Returns:
             The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
                     output_stream=f,
                     max_size=self.max_upload_size,
                     max_timeout_ms=max_timeout_ms,
+                    download_ratelimiter=download_ratelimiter,
+                    ip_address=ip_address,
                 )
             except RequestSendFailed as e:
                 logger.warning(
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index cc3acf51e1..f8a9560784 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -359,9 +359,10 @@ class ThumbnailProvider:
         desired_method: str,
         desired_type: str,
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
+            server_name, media_id, max_timeout_ms, ip_address
         )
         if not media_info:
             respond_404(request)
@@ -422,12 +423,13 @@ class ThumbnailProvider:
         method: str,
         m_type: str,
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
         media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
+            server_name, media_id, max_timeout_ms, ip_address
         )
         if not media_info:
             return
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index 172d240783..0c089163c1 100644
--- a/synapse/rest/client/media.py
+++ b/synapse/rest/client/media.py
@@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             remote_resp_function = (
                 self.thumbnailer.select_or_generate_remote_thumbnail
                 if self.dynamic_thumbnails
@@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
                 method,
                 m_type,
                 max_timeout_ms,
+                ip_address,
             )
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8ba723c8d4..1628d58926 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             await self.media_repo.get_remote_media(
-                request, server_name, media_id, file_name, max_timeout_ms
+                request,
+                server_name,
+                media_id,
+                file_name,
+                max_timeout_ms,
+                ip_address,
             )
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index fe8fbb06e4..ce511c6dce 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             remote_resp_function = (
                 self.thumbnail_provider.select_or_generate_remote_thumbnail
                 if self.dynamic_thumbnails
@@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
                 method,
                 m_type,
                 max_timeout_ms,
+                ip_address,
             )
             self.media_repo.mark_recently_accessed(server_name, media_id)
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 1bd51ceba2..46d20ce775 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -25,7 +25,7 @@ import tempfile
 from binascii import unhexlify
 from io import BytesIO
 from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
 from urllib import parse
 
 import attr
@@ -37,9 +37,12 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
 from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
@@ -59,6 +62,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.server import FakeChannel
 from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
 from tests.utils import default_config
 
 
@@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             destination: str,
             path: str,
             output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
             args: Optional[QueryParams] = None,
             retry_on_dns_fail: bool = True,
-            max_size: Optional[int] = None,
             ignore_backoff: bool = False,
             follow_redirects: bool = False,
         ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
             tok=self.tok,
             expect_code=400,
         )
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+        config["media_store_path"] = self.media_store_path
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        config["media_storage_providers"] = [provider_config]
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.repo = hs.get_media_repository()
+        self.client = hs.get_federation_http_client()
+        self.store = hs.get_datastores().main
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        # We need to manually set the resource tree to include media, the
+        # default only does `/_matrix/client` APIs.
+        return {"/_matrix/media": self.hs.get_media_repository_resource()}
+
+    # mock actually reading file body
+    def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(31457280)
+        return d
+
+    def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(52428800)
+        return d
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_30MiB,
+    )
+    def test_download_ratelimit_default(self) -> None:
+        """
+        Test remote media download ratelimiting against default configuration - 500MB bucket
+        and 87kb/second drain rate
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+        # next 15 should go through
+        for i in range(15):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+                shorthand=False,
+            )
+            assert channel2.code == 200
+
+        # 17th will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+            shorthand=False,
+        )
+        assert channel3.code == 429
+
+        # however, a request from a different IP will go through
+        channel4 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+            client_ip="187.233.230.159",
+        )
+        assert channel4.code == 200
+
+        # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
+        # 30MiB download is authorized - The last download was blocked at 503,316,480.
+        # The next download will be authorized when bucket hits 492,830,720
+        # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
+        # needs to drain before another download will be authorized, that will take ~=
+        # 2 minutes (10,485,760/89,088/60)
+        self.reactor.pump([2.0 * 60.0])
+
+        # enough has drained and next request goes through
+        channel5 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
+            shorthand=False,
+        )
+        assert channel5.code == 200
+
+    @override_config(
+        {
+            "remote_media_download_per_second": "50M",
+            "remote_media_download_burst_count": "50M",
+        }
+    )
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_50MiB,
+    )
+    def test_download_rate_limit_config(self) -> None:
+        """
+        Test that download rate limit config options are correctly picked up and applied
+        """
+
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 52428800
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+        # immediate second request should fail
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
+            shorthand=False,
+        )
+        assert channel.code == 429
+
+        # advance half a second
+        self.reactor.pump([0.5])
+
+        # request still fails
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
+            shorthand=False,
+        )
+        assert channel.code == 429
+
+        # advance another half second
+        self.reactor.pump([0.5])
+
+        # enough has drained from bucket and request is successful
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_30MiB,
+    )
+    def test_download_ratelimit_max_size_sub(self) -> None:
+        """
+        Test that if no content-length is provided, the default max size is applied instead
+        """
+
+        # mock out actually sending the request
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = UNKNOWN_LENGTH
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # ten requests should go through using the max size (500MB/50MB)
+        for i in range(10):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+                shorthand=False,
+            )
+            assert channel2.code == 200
+
+        # eleventh will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+            shorthand=False,
+        )
+        assert channel3.code == 429