diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 9c29e09653..6ed56099ca 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
cs_error,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
)
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.download_ratelimiter = Ratelimiter(
+ store=hs.get_storage_controllers().main,
+ clock=hs.get_clock(),
+ cfg=hs.config.ratelimiting.remote_media_downloads,
+ )
+
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
+ ip_address: str,
) -> None:
"""Respond to requests for remote media.
@@ -475,6 +483,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: the IP address of the requester
Returns:
Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: IP address of the requester
Returns:
The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self,
+ server_name: str,
+ media_id: str,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP.
+ ip_address: the IP address of the requester
Returns:
A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
- server_name, media_id, max_timeout_ms
+ server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
)
except SynapseError:
raise
@@ -630,6 +656,8 @@ class MediaRepository:
server_name: str,
media_id: str,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP
+ ip_address: the IP address of the requester
Returns:
The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
|