summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17439.bugfix1
-rw-r--r--synapse/http/matrixfederationclient.py126
-rw-r--r--tests/media/test_media_storage.py49
-rw-r--r--tests/rest/client/test_media.py50
4 files changed, 166 insertions, 60 deletions
diff --git a/changelog.d/17439.bugfix b/changelog.d/17439.bugfix
new file mode 100644
index 0000000000..f36c3ec255
--- /dev/null
+++ b/changelog.d/17439.bugfix
@@ -0,0 +1 @@
+Limit concurrent remote downloads to 6 per IP address, and decrement remote downloads without a content-length from the ratelimiter after the download is complete.
\ No newline at end of file
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 749b01dd0e..6fd75fd381 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -90,7 +90,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.types import JsonDict
 from synapse.util import json_decoder
-from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
+from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -475,6 +475,8 @@ class MatrixFederationHttpClient:
             use_proxy=True,
         )
 
+        self.remote_download_linearizer = Linearizer("remote_download_linearizer", 6)
+
     def wake_destination(self, destination: str) -> None:
         """Called when the remote server may have come back online."""
 
@@ -1486,35 +1488,44 @@ class MatrixFederationHttpClient:
         )
 
         headers = dict(response.headers.getAllRawHeaders())
-
         expected_size = response.length
-        # if we don't get an expected length then use the max length
+
         if expected_size == UNKNOWN_LENGTH:
             expected_size = max_size
-            logger.debug(
-                f"File size unknown, assuming file is max allowable size: {max_size}"
-            )
+        else:
+            if int(expected_size) > max_size:
+                msg = "Requested file is too large > %r bytes" % (max_size,)
+                logger.warning(
+                    "{%s} [%s] %s",
+                    request.txn_id,
+                    request.destination,
+                    msg,
+                )
+                raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
 
-        read_body, _ = await download_ratelimiter.can_do_action(
-            requester=None,
-            key=ip_address,
-            n_actions=expected_size,
-        )
-        if not read_body:
-            msg = "Requested file size exceeds ratelimits"
-            logger.warning(
-                "{%s} [%s] %s",
-                request.txn_id,
-                request.destination,
-                msg,
+            read_body, _ = await download_ratelimiter.can_do_action(
+                requester=None,
+                key=ip_address,
+                n_actions=expected_size,
             )
-            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+            if not read_body:
+                msg = "Requested file size exceeds ratelimits"
+                logger.warning(
+                    "{%s} [%s] %s",
+                    request.txn_id,
+                    request.destination,
+                    msg,
+                )
+                raise SynapseError(
+                    HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED
+                )
 
         try:
-            # add a byte of headroom to max size as function errs at >=
-            d = read_body_with_max_size(response, output_stream, expected_size + 1)
-            d.addTimeout(self.default_timeout_seconds, self.reactor)
-            length = await make_deferred_yieldable(d)
+            async with self.remote_download_linearizer.queue(ip_address):
+                # add a byte of headroom to max size as function errs at >=
+                d = read_body_with_max_size(response, output_stream, expected_size + 1)
+                d.addTimeout(self.default_timeout_seconds, self.reactor)
+                length = await make_deferred_yieldable(d)
         except BodyExceededMaxSize:
             msg = "Requested file is too large > %r bytes" % (expected_size,)
             logger.warning(
@@ -1560,6 +1571,13 @@ class MatrixFederationHttpClient:
             request.method,
             request.uri.decode("ascii"),
         )
+
+        # if we didn't know the length upfront, decrement the actual size from ratelimiter
+        if response.length == UNKNOWN_LENGTH:
+            download_ratelimiter.record_action(
+                requester=None, key=ip_address, n_actions=length
+            )
+
         return length, headers
 
     async def federation_get_file(
@@ -1630,29 +1648,37 @@ class MatrixFederationHttpClient:
         )
 
         headers = dict(response.headers.getAllRawHeaders())
-
         expected_size = response.length
-        # if we don't get an expected length then use the max length
+
         if expected_size == UNKNOWN_LENGTH:
             expected_size = max_size
-            logger.debug(
-                f"File size unknown, assuming file is max allowable size: {max_size}"
-            )
+        else:
+            if int(expected_size) > max_size:
+                msg = "Requested file is too large > %r bytes" % (max_size,)
+                logger.warning(
+                    "{%s} [%s] %s",
+                    request.txn_id,
+                    request.destination,
+                    msg,
+                )
+                raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
 
-        read_body, _ = await download_ratelimiter.can_do_action(
-            requester=None,
-            key=ip_address,
-            n_actions=expected_size,
-        )
-        if not read_body:
-            msg = "Requested file size exceeds ratelimits"
-            logger.warning(
-                "{%s} [%s] %s",
-                request.txn_id,
-                request.destination,
-                msg,
+            read_body, _ = await download_ratelimiter.can_do_action(
+                requester=None,
+                key=ip_address,
+                n_actions=expected_size,
             )
-            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+            if not read_body:
+                msg = "Requested file size exceeds ratelimits"
+                logger.warning(
+                    "{%s} [%s] %s",
+                    request.txn_id,
+                    request.destination,
+                    msg,
+                )
+                raise SynapseError(
+                    HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED
+                )
 
         # this should be a multipart/mixed response with the boundary string in the header
         try:
@@ -1672,11 +1698,12 @@ class MatrixFederationHttpClient:
             raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
 
         try:
-            # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
-            deferred = read_multipart_response(
-                response, output_stream, boundary, expected_size + 1
-            )
-            deferred.addTimeout(self.default_timeout_seconds, self.reactor)
+            async with self.remote_download_linearizer.queue(ip_address):
+                # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
+                deferred = read_multipart_response(
+                    response, output_stream, boundary, expected_size + 1
+                )
+                deferred.addTimeout(self.default_timeout_seconds, self.reactor)
         except BodyExceededMaxSize:
             msg = "Requested file is too large > %r bytes" % (expected_size,)
             logger.warning(
@@ -1743,6 +1770,13 @@ class MatrixFederationHttpClient:
             request.method,
             request.uri.decode("ascii"),
         )
+
+        # if we didn't know the length upfront, decrement the actual size from ratelimiter
+        if response.length == UNKNOWN_LENGTH:
+            download_ratelimiter.record_action(
+                requester=None, key=ip_address, n_actions=length
+            )
+
         return length, headers, multipart_response.json
 
 
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 70912e22f8..e55001fb40 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -1057,13 +1057,15 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
         )
         assert channel.code == 200
 
+    @override_config({"remote_media_download_burst_count": "87M"})
     @patch(
         "synapse.http.matrixfederationclient.read_body_with_max_size",
         read_body_with_max_size_30MiB,
     )
-    def test_download_ratelimit_max_size_sub(self) -> None:
+    def test_download_ratelimit_unknown_length(self) -> None:
         """
-        Test that if no content-length is provided, the default max size is applied instead
+        Test that if no content-length is provided, ratelimit will still be applied after
+        download once length is known
         """
 
         # mock out actually sending the request
@@ -1077,19 +1079,48 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
 
         self.client._send_request = _send_request  # type: ignore
 
-        # ten requests should go through using the max size (500MB/50MB)
-        for i in range(10):
-            channel2 = self.make_request(
+        # 3 requests should go through (note 3rd one would technically violate ratelimit but
+        # is applied *after* download - the next one will be ratelimited)
+        for i in range(3):
+            channel = self.make_request(
                 "GET",
                 f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
                 shorthand=False,
             )
-            assert channel2.code == 200
+            assert channel.code == 200
 
-        # eleventh will hit ratelimit
-        channel3 = self.make_request(
+        # 4th will hit ratelimit
+        channel2 = self.make_request(
             "GET",
             "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
             shorthand=False,
         )
-        assert channel3.code == 429
+        assert channel2.code == 429
+
+    @override_config({"max_upload_size": "29M"})
+    @patch(
+        "synapse.http.matrixfederationclient.read_body_with_max_size",
+        read_body_with_max_size_30MiB,
+    )
+    def test_max_download_respected(self) -> None:
+        """
+        Test that the max download size is enforced - note that max download size is determined
+        by the max_upload_size
+        """
+
+        # mock out actually sending the request
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        channel = self.make_request(
+            "GET", "/_matrix/media/v3/download/remote.org/abcd", shorthand=False
+        )
+        assert channel.code == 502
+        assert channel.json_body["errcode"] == "M_TOO_LARGE"
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 7f2caed7d5..466c5a0b70 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -1809,13 +1809,19 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
         )
         assert channel.code == 200
 
+    @override_config(
+        {
+            "remote_media_download_burst_count": "87M",
+        }
+    )
     @patch(
         "synapse.http.matrixfederationclient.read_multipart_response",
         read_multipart_response_30MiB,
     )
-    def test_download_ratelimit_max_size_sub(self) -> None:
+    def test_download_ratelimit_unknown_length(self) -> None:
         """
-        Test that if no content-length is provided, the default max size is applied instead
+        Test that if no content-length is provided, ratelimiting is still applied after
+        media is downloaded and length is known
         """
 
         # mock out actually sending the request
@@ -1831,8 +1837,9 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
 
         self.client._send_request = _send_request  # type: ignore
 
-        # ten requests should go through using the max size (500MB/50MB)
-        for i in range(10):
+        # first 3 will go through (note that 3rd request technically violates rate limit but
+        # that since the ratelimiting is applied *after* download it goes through, but next one fails)
+        for i in range(3):
             channel2 = self.make_request(
                 "GET",
                 f"/_matrix/client/v1/media/download/remote.org/abc{i}",
@@ -1841,7 +1848,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
             )
             assert channel2.code == 200
 
-        # eleventh will hit ratelimit
+        # 4th will hit ratelimit
         channel3 = self.make_request(
             "GET",
             "/_matrix/client/v1/media/download/remote.org/abcd",
@@ -1850,6 +1857,39 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
         )
         assert channel3.code == 429
 
+    @override_config({"max_upload_size": "29M"})
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_max_download_respected(self) -> None:
+        """
+        Test that the max download size is enforced - note that max download size is determined
+        by the max_upload_size
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 502
+        assert channel.json_body["errcode"] == "M_TOO_LARGE"
+
     def test_file_download(self) -> None:
         content = io.BytesIO(b"file_to_stream")
         content_uri = self.get_success(