summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-06-05 05:43:36 -0700
committerGitHub <noreply@github.com>2024-06-05 13:43:36 +0100
commitfcbc79bb87d08147e86dafa0fee5a9aec4d3fc23 (patch)
tree5d5efc029b3410f96ce9aa5dc49b55e7ef4755c5 /synapse
parentHandle hyphens in user dir search porperly (#17254) (diff)
downloadsynapse-fcbc79bb87d08147e86dafa0fee5a9aec4d3fc23.tar.xz
Ratelimiting of remote media downloads (#17256)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/ratelimiting.py10
-rw-r--r--synapse/federation/federation_client.py7
-rw-r--r--synapse/federation/transport/client.py9
-rw-r--r--synapse/http/matrixfederationclient.py55
-rw-r--r--synapse/media/media_repository.py43
-rw-r--r--synapse/media/thumbnailer.py6
-rw-r--r--synapse/rest/client/media.py2
-rw-r--r--synapse/rest/media/download_resource.py8
-rw-r--r--synapse/rest/media/thumbnail_resource.py2
9 files changed, 130 insertions, 12 deletions
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index d2cb4576df..3fa33f5373 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -218,3 +218,13 @@ class RatelimitConfig(Config):
             "rc_media_create",
             defaults={"per_second": 10, "burst_count": 50},
         )
+
+        self.remote_media_downloads = RatelimitSettings(
+            key="rc_remote_media_downloads",
+            per_second=self.parse_size(
+                config.get("remote_media_download_per_second", "87K")
+            ),
+            burst_count=self.parse_size(
+                config.get("remote_media_download_burst_count", "500M")
+            ),
+        )
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e613eb87a6..f0f5a37a57 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,6 +56,7 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         try:
             return await self.transport_layer.download_media_v3(
@@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
                 output_stream=output_stream,
                 max_size=max_size,
                 max_timeout_ms=max_timeout_ms,
+                download_ratelimiter=download_ratelimiter,
+                ip_address=ip_address,
             )
         except HttpResponseException as e:
             # If an error is received that is due to an unrecognised endpoint,
@@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
             output_stream=output_stream,
             max_size=max_size,
             max_timeout_ms=max_timeout_ms,
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index de408f7f8d..af1336fe5f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -43,6 +43,7 @@ import ijson
 
 from synapse.api.constants import Direction, Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import RoomVersion
 from synapse.api.urls import (
     FEDERATION_UNSTABLE_PREFIX,
@@ -819,6 +820,8 @@ class TransportLayerClient:
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/r0/download/{destination}/{media_id}"
 
@@ -834,6 +837,8 @@ class TransportLayerClient:
                 "allow_remote": "false",
                 "timeout_ms": str(max_timeout_ms),
             },
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
     async def download_media_v3(
@@ -843,6 +848,8 @@ class TransportLayerClient:
         output_stream: BinaryIO,
         max_size: int,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
         path = f"/_matrix/media/v3/download/{destination}/{media_id}"
 
@@ -862,6 +869,8 @@ class TransportLayerClient:
                 "allow_redirect": "true",
             },
             follow_redirects=True,
+            download_ratelimiter=download_ratelimiter,
+            ip_address=ip_address,
         )
 
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c73a589e6c..104b803b0f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
 
 import synapse.metrics
 import synapse.util.retryutils
@@ -68,6 +68,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
@@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
         destination: str,
         path: str,
         output_stream: BinaryIO,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+        max_size: int,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
-        max_size: Optional[int] = None,
         ignore_backoff: bool = False,
         follow_redirects: bool = False,
     ) -> Tuple[int, Dict[bytes, List[bytes]]]:
@@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
             destination: The remote server to send the HTTP request to.
             path: The HTTP path to GET.
             output_stream: File to write the response body to.
+            download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+                requester IP
+            ip_address: IP address of the requester
+            max_size: maximum allowable size in bytes of the file
             args: Optional dictionary used to create the query string.
             ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
@@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
                 federation whitelist
             RequestSendFailed: If there were problems connecting to the
                 remote, due to e.g. DNS failures, connection timeouts etc.
+            SynapseError: If the requested file exceeds ratelimits
         """
         request = MatrixFederationRequest(
             method="GET", destination=destination, path=path, query=args
         )
 
+        # check for a minimum balance of 1MiB in ratelimiter before initiating request
+        send_req, _ = await download_ratelimiter.can_do_action(
+            requester=None, key=ip_address, n_actions=1048576, update=False
+        )
+
+        if not send_req:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
         response = await self._send_request(
             request,
             retry_on_dns_fail=retry_on_dns_fail,
@@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
 
         headers = dict(response.headers.getAllRawHeaders())
 
+        expected_size = response.length
+        # if we don't get an expected length then use the max length
+        if expected_size == UNKNOWN_LENGTH:
+            expected_size = max_size
+            logger.debug(
+                f"File size unknown, assuming file is max allowable size: {max_size}"
+            )
+
+        read_body, _ = await download_ratelimiter.can_do_action(
+            requester=None,
+            key=ip_address,
+            n_actions=expected_size,
+        )
+        if not read_body:
+            msg = "Requested file size exceeds ratelimits"
+            logger.warning(
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
+            )
+            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
         try:
-            d = read_body_with_max_size(response, output_stream, max_size)
+            # add a byte of headroom to max size as function errs at >=
+            d = read_body_with_max_size(response, output_stream, expected_size + 1)
             d.addTimeout(self.default_timeout_seconds, self.reactor)
             length = await make_deferred_yieldable(d)
         except BodyExceededMaxSize:
-            msg = "Requested file is too large > %r bytes" % (max_size,)
+            msg = "Requested file is too large > %r bytes" % (expected_size,)
             logger.warning(
                 "{%s} [%s] %s",
                 request.txn_id,
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 9c29e09653..6ed56099ca 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
     SynapseError,
     cs_error,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.repository import ThumbnailRequirement
 from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
         )
         self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
 
+        self.download_ratelimiter = Ratelimiter(
+            store=hs.get_storage_controllers().main,
+            clock=hs.get_clock(),
+            cfg=hs.config.ratelimiting.remote_media_downloads,
+        )
+
         # List of StorageProviders where we should search for media and
         # potentially upload to.
         storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
         media_id: str,
         name: Optional[str],
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -475,6 +483,7 @@ class MediaRepository:
                 the filename in the Content-Disposition header of the response.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: the IP address of the requester
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
             respond_404(request)
 
     async def get_remote_media_info(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
     ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: IP address of the requester
 
         Returns:
             The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self,
+        server_name: str,
+        media_id: str,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
                 remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP.
+            ip_address: the IP address of the requester
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name, media_id, max_timeout_ms
+                server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
             )
         except SynapseError:
             raise
@@ -630,6 +656,8 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
                 locally generated.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP
+            ip_address: the IP address of the requester
 
         Returns:
             The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
                     output_stream=f,
                     max_size=self.max_upload_size,
                     max_timeout_ms=max_timeout_ms,
+                    download_ratelimiter=download_ratelimiter,
+                    ip_address=ip_address,
                 )
             except RequestSendFailed as e:
                 logger.warning(
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index cc3acf51e1..f8a9560784 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -359,9 +359,10 @@ class ThumbnailProvider:
         desired_method: str,
         desired_type: str,
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
+            server_name, media_id, max_timeout_ms, ip_address
         )
         if not media_info:
             respond_404(request)
@@ -422,12 +423,13 @@ class ThumbnailProvider:
         method: str,
         m_type: str,
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
         media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
+            server_name, media_id, max_timeout_ms, ip_address
         )
         if not media_info:
             return
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index 172d240783..0c089163c1 100644
--- a/synapse/rest/client/media.py
+++ b/synapse/rest/client/media.py
@@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             remote_resp_function = (
                 self.thumbnailer.select_or_generate_remote_thumbnail
                 if self.dynamic_thumbnails
@@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
                 method,
                 m_type,
                 max_timeout_ms,
+                ip_address,
             )
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8ba723c8d4..1628d58926 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             await self.media_repo.get_remote_media(
-                request, server_name, media_id, file_name, max_timeout_ms
+                request,
+                server_name,
+                media_id,
+                file_name,
+                max_timeout_ms,
+                ip_address,
             )
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index fe8fbb06e4..ce511c6dce 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
+            ip_address = request.getClientAddress().host
             remote_resp_function = (
                 self.thumbnail_provider.select_or_generate_remote_thumbnail
                 if self.dynamic_thumbnails
@@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
                 method,
                 m_type,
                 max_timeout_ms,
+                ip_address,
             )
             self.media_repo.mark_recently_accessed(server_name, media_id)