diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 1bd51ceba2..46d20ce775 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -25,7 +25,7 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
from urllib import parse
import attr
@@ -37,9 +37,12 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
@@ -59,6 +62,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
from tests.utils import default_config
@@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
destination: str,
path: str,
output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: Any,
+ max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
- max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
tok=self.tok,
expect_code=400,
)
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.repo = hs.get_media_repository()
+ self.client = hs.get_federation_http_client()
+ self.store = hs.get_datastores().main
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ # We need to manually set the resource tree to include media, the
+ # default only does `/_matrix/client` APIs.
+ return {"/_matrix/media": self.hs.get_media_repository_resource()}
+
+ # stand-ins for read_body_with_max_size: skip reading the body and fire immediately with a fixed size (30MiB / 50MiB)
+ def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+ d: Deferred = defer.Deferred()
+ d.callback(31457280)
+ return d
+
+ def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+ d: Deferred = defer.Deferred()
+ d.callback(52428800)
+ return d
+
+ @patch(
+ "synapse.http.matrixfederationclient.read_body_with_max_size",
+ read_body_with_max_size_30MiB,
+ )
+ def test_download_ratelimit_default(self) -> None:
+ """
+ Test remote media download ratelimiting against default configuration - 500MiB bucket
+ and 87KiB/second drain rate
+ """
+
+ # mock out actually sending the request, returns a 30MiB response
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 31457280
+ resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # first request should go through
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+ shorthand=False,
+ )
+ assert channel.code == 200
+
+ # next 15 should go through
+ for i in range(15):
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+ shorthand=False,
+ )
+ assert channel2.code == 200
+
+ # 17th will hit ratelimit
+ channel3 = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+ shorthand=False,
+ )
+ assert channel3.code == 429
+
+ # however, a request from a different IP will go through
+ channel4 = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+ shorthand=False,
+ client_ip="187.233.230.159",
+ )
+ assert channel4.code == 200
+
+ # at 87KiB/s it should take about 2 minutes for enough to drain from bucket that another
+ # 30MiB download is authorized - the last download was blocked with the bucket at 503,316,480 (16 x 31,457,280).
+ # The next download will be authorized when bucket hits 492,830,720
+ # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 = 10,485,760
+ # needs to drain before another download will be authorized, that will take ~=
+ # 2 minutes (10,485,760/89,088/60)
+ self.reactor.pump([2.0 * 60.0])
+
+ # enough has drained and next request goes through
+ channel5 = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
+ shorthand=False,
+ )
+ assert channel5.code == 200
+
+ @override_config(
+ {
+ "remote_media_download_per_second": "50M",
+ "remote_media_download_burst_count": "50M",
+ }
+ )
+ @patch(
+ "synapse.http.matrixfederationclient.read_body_with_max_size",
+ read_body_with_max_size_50MiB,
+ )
+ def test_download_rate_limit_config(self) -> None:
+ """
+ Test that download rate limit config options are correctly picked up and applied
+ """
+
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 52428800
+ resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # first request should go through
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+ shorthand=False,
+ )
+ assert channel.code == 200
+
+ # immediate second request should fail
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
+ shorthand=False,
+ )
+ assert channel.code == 429
+
+ # advance half a second
+ self.reactor.pump([0.5])
+
+ # request still fails
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
+ shorthand=False,
+ )
+ assert channel.code == 429
+
+ # advance another half second
+ self.reactor.pump([0.5])
+
+ # enough has drained from bucket and request is successful
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
+ shorthand=False,
+ )
+ assert channel.code == 200
+
+ @patch(
+ "synapse.http.matrixfederationclient.read_body_with_max_size",
+ read_body_with_max_size_30MiB,
+ )
+ def test_download_ratelimit_max_size_sub(self) -> None:
+ """
+ Test that if no content-length is provided, the default max size is applied instead
+ """
+
+ # mock out actually sending the request, returns a response of unknown length
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = UNKNOWN_LENGTH
+ resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # ten requests should go through using the max size (500MiB bucket / 50MiB max size)
+ for i in range(10):
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+ shorthand=False,
+ )
+ assert channel2.code == 200
+
+ # eleventh will hit ratelimit
+ channel3 = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+ shorthand=False,
+ )
+ assert channel3.code == 429
|