summary refs log tree commit diff
path: root/tests/media/test_media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/media/test_media_storage.py')
-rw-r--r--tests/media/test_media_storage.py225
1 files changed, 223 insertions, 2 deletions
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 1bd51ceba2..46d20ce775 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -25,7 +25,7 @@ import tempfile
 from binascii import unhexlify
 from io import BytesIO
 from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
 from urllib import parse
 
 import attr
@@ -37,9 +37,12 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
 from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
@@ -59,6 +62,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.server import FakeChannel
 from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
 from tests.utils import default_config
 
 
@@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             destination: str,
             path: str,
             output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
             args: Optional[QueryParams] = None,
             retry_on_dns_fail: bool = True,
-            max_size: Optional[int] = None,
             ignore_backoff: bool = False,
             follow_redirects: bool = False,
         ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
             tok=self.tok,
             expect_code=400,
         )
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+        config["media_store_path"] = self.media_store_path
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        config["media_storage_providers"] = [provider_config]
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.repo = hs.get_media_repository()
+        self.client = hs.get_federation_http_client()
+        self.store = hs.get_datastores().main
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        # We need to manually set the resource tree to include media, the
+        # default only does `/_matrix/client` APIs.
+        return {"/_matrix/media": self.hs.get_media_repository_resource()}
+
+    # mock actually reading file body
+    def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(31457280)
+        return d
+
+    def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(52428800)
+        return d
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_30MiB,
+    )
+    def test_download_ratelimit_default(self) -> None:
+        """
+        Test remote media download ratelimiting against default configuration - 500MB bucket
+        and 87kb/second drain rate
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+        # next 15 should go through
+        for i in range(15):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+                shorthand=False,
+            )
+            assert channel2.code == 200
+
+        # 17th will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+            shorthand=False,
+        )
+        assert channel3.code == 429
+
+        # however, a request from a different IP will go through
+        channel4 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+            client_ip="187.233.230.159",
+        )
+        assert channel4.code == 200
+
+        # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
+        # 30MiB download is authorized - The last download was blocked at 503,316,480.
+        # The next download will be authorized when bucket hits 492,830,720
+        # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
+        # needs to drain before another download will be authorized, that will take ~=
+        # 2 minutes (10,485,760/89,088/60)
+        self.reactor.pump([2.0 * 60.0])
+
+        # enough has drained and next request goes through
+        channel5 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
+            shorthand=False,
+        )
+        assert channel5.code == 200
+
+    @override_config(
+        {
+            "remote_media_download_per_second": "50M",
+            "remote_media_download_burst_count": "50M",
+        }
+    )
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_50MiB,
+    )
+    def test_download_rate_limit_config(self) -> None:
+        """
+        Test that download rate limit config options are correctly picked up and applied
+        """
+
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 52428800
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+        # immediate second request should fail
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
+            shorthand=False,
+        )
+        assert channel.code == 429
+
+        # advance half a second
+        self.reactor.pump([0.5])
+
+        # request still fails
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
+            shorthand=False,
+        )
+        assert channel.code == 429
+
+        # advance another half second
+        self.reactor.pump([0.5])
+
+        # enough has drained from bucket and request is successful
+        channel = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
+            shorthand=False,
+        )
+        assert channel.code == 200
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_30MiB,
+    )
+    def test_download_ratelimit_max_size_sub(self) -> None:
+        """
+        Test that if no content-length is provided, the default max size is applied instead
+        """
+
+        # mock out actually sending the request
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = UNKNOWN_LENGTH
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # ten requests should go through using the max size (500MB/50MB)
+        for i in range(10):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
+                shorthand=False,
+            )
+            assert channel2.code == 200
+
+        # eleventh will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
+            shorthand=False,
+        )
+        assert channel3.code == 429