diff options
Diffstat (limited to 'tests/media/test_media_storage.py')
-rw-r--r-- | tests/media/test_media_storage.py | 225 |
1 files changed, 223 insertions, 2 deletions
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 1bd51ceba2..46d20ce775 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -25,7 +25,7 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch from urllib import parse import attr @@ -37,9 +37,12 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers +from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource from synapse.api.errors import Codes, HttpResponseException +from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable @@ -59,6 +62,7 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeChannel from tests.test_utils import SMALL_PNG +from tests.unittest import override_config from tests.utils import default_config @@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): destination: str, path: str, output_stream: BinaryIO, + download_ratelimiter: Ratelimiter, + ip_address: Any, + max_size: int, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, - max_size: Optional[int] = None, ignore_backoff: bool = False, follow_redirects: bool = False, ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": @@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): tok=self.tok, expect_code=400, ) + + +class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + + def create_resource_dict(self) -> Dict[str, Resource]: + # We need to manually set the resource tree to include media, the + # default only does `/_matrix/client` APIs. + return {"/_matrix/media": self.hs.get_media_repository_resource()} + + # mock actually reading file body + def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(31457280) + return d + + def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred: + d: Deferred = defer.Deferred() + d.callback(52428800) + return d + + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_30MiB, + ) + def test_download_ratelimit_default(self) -> None: + """ + Test remote media download ratelimiting against default configuration - 500MB bucket + and 87kb/second drain rate + """ + + # mock out actually sending the request, returns a 30MiB response + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 31457280 + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + ) + assert channel.code == 200 + + # next 15 should go through + for i in range(15): + channel2 = self.make_request( + "GET", + f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", + shorthand=False, + ) + assert channel2.code == 200 + + # 17th will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", + shorthand=False, + ) + assert channel3.code == 429 + + # however, a request from a different IP will go through + channel4 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + client_ip="187.233.230.159", + ) + assert channel4.code == 200 + + # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another + # 30MiB download is authorized - The last download was blocked at 503,316,480. + # The next download will be authorized when bucket hits 492,830,720 + # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760 + # needs to drain before another download will be authorized, that will take ~= + # 2 minutes (10,485,760/89,088/60) + self.reactor.pump([2.0 * 60.0]) + + # enough has drained and next request goes through + channel5 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb", + shorthand=False, + ) + assert channel5.code == 200 + + @override_config( + { + "remote_media_download_per_second": "50M", + "remote_media_download_burst_count": "50M", + } + ) + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_50MiB, + ) + def test_download_rate_limit_config(self) -> None: + """ + Test that download rate limit config options are correctly picked up and applied + """ + + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 52428800 + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # first request should go through + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", + shorthand=False, + ) + assert channel.code == 200 + + # immediate second request should fail + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1", + shorthand=False, + ) + assert channel.code == 429 + + # advance half a second + self.reactor.pump([0.5]) + + # request still fails + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2", + shorthand=False, + ) + assert channel.code == 429 + + # advance another half second + self.reactor.pump([0.5]) + + # enough has drained from bucket and request is successful + channel = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3", + shorthand=False, + ) + assert channel.code == 200 + + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_30MiB, + ) + def test_download_ratelimit_max_size_sub(self) -> None: + """ + Test that if no content-length is provided, the default max size is applied instead + """ + + # mock out actually sending the request + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = UNKNOWN_LENGTH + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + # ten requests should go through using the max size (500MB/50MB) + for i in range(10): + channel2 = self.make_request( + "GET", + f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", + shorthand=False, + ) + assert channel2.code == 200 + + # eleventh will hit ratelimit + channel3 = self.make_request( + "GET", + "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", + shorthand=False, + ) + assert channel3.code == 429 |