summary refs log tree commit diff
path: root/synapse/media/media_repository.py
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-06-05 05:43:36 -0700
committerGitHub <noreply@github.com>2024-06-05 13:43:36 +0100
commitfcbc79bb87d08147e86dafa0fee5a9aec4d3fc23 (patch)
tree5d5efc029b3410f96ce9aa5dc49b55e7ef4755c5 /synapse/media/media_repository.py
parentHandle hyphens in user dir search porperly (#17254) (diff)
downloadsynapse-fcbc79bb87d08147e86dafa0fee5a9aec4d3fc23.tar.xz
Ratelimiting of remote media downloads (#17256)
Diffstat (limited to 'synapse/media/media_repository.py')
-rw-r--r--synapse/media/media_repository.py43
1 files changed, 38 insertions, 5 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 9c29e09653..6ed56099ca 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
     SynapseError,
     cs_error,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.repository import ThumbnailRequirement
 from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
         )
         self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
 
+        self.download_ratelimiter = Ratelimiter(
+            store=hs.get_storage_controllers().main,
+            clock=hs.get_clock(),
+            cfg=hs.config.ratelimiting.remote_media_downloads,
+        )
+
         # List of StorageProviders where we should search for media and
         # potentially upload to.
         storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
         media_id: str,
         name: Optional[str],
         max_timeout_ms: int,
+        ip_address: str,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -475,6 +483,7 @@ class MediaRepository:
                 the filename in the Content-Disposition header of the response.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: the IP address of the requester
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
             respond_404(request)
 
     async def get_remote_media_info(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
     ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            ip_address: IP address of the requester
 
         Returns:
             The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id, max_timeout_ms
+                server_name,
+                media_id,
+                max_timeout_ms,
+                self.download_ratelimiter,
+                ip_address,
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str, max_timeout_ms: int
+        self,
+        server_name: str,
+        media_id: str,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
                 remote server).
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP.
+            ip_address: the IP address of the requester
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name, media_id, max_timeout_ms
+                server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
             )
         except SynapseError:
             raise
@@ -630,6 +656,8 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
     ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
                 locally generated.
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP
+            ip_address: the IP address of the requester
 
         Returns:
             The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
                     output_stream=f,
                     max_size=self.max_upload_size,
                     max_timeout_ms=max_timeout_ms,
+                    download_ratelimiter=download_ratelimiter,
+                    ip_address=ip_address,
                 )
             except RequestSendFailed as e:
                 logger.warning(