diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index d2cb4576df..3fa33f5373 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -218,3 +218,13 @@ class RatelimitConfig(Config):
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)
+
+ self.remote_media_downloads = RatelimitSettings(
+ key="rc_remote_media_downloads",
+ per_second=self.parse_size(
+ config.get("remote_media_download_per_second", "87K")
+ ),
+ burst_count=self.parse_size(
+ config.get("remote_media_download_burst_count", "500M")
+ ),
+ )
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e613eb87a6..f0f5a37a57 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,6 +56,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
@@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
@@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index de408f7f8d..af1336fe5f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -43,6 +43,7 @@ import ijson
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
@@ -819,6 +820,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
@@ -834,6 +837,8 @@ class TransportLayerClient:
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
async def download_media_v3(
@@ -843,6 +848,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
@@ -862,6 +869,8 @@ class TransportLayerClient:
"allow_redirect": "true",
},
follow_redirects=True,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c73a589e6c..104b803b0f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -68,6 +68,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
@@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
- max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
@@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
+ download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+ requester IP
+ ip_address: IP address of the requester
+ max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
@@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
+ SynapseError: If the requested file exceeds ratelimits
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
+ # check for a minimum balance of 1MiB in ratelimiter before initiating request
+ send_req, _ = await download_ratelimiter.can_do_action(
+ requester=None, key=ip_address, n_actions=1048576, update=False
+ )
+
+ if not send_req:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
@@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
+ expected_size = response.length
+ # if we don't get an expected length then use the max length
+ if expected_size == UNKNOWN_LENGTH:
+ expected_size = max_size
+ logger.debug(
+ f"File size unknown, assuming file is max allowable size: {max_size}"
+ )
+
+ read_body, _ = await download_ratelimiter.can_do_action(
+ requester=None,
+ key=ip_address,
+ n_actions=expected_size,
+ )
+ if not read_body:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
try:
- d = read_body_with_max_size(response, output_stream, max_size)
+ # add a byte of headroom to max size as function errs at >=
+ d = read_body_with_max_size(response, output_stream, expected_size + 1)
d.addTimeout(self.default_timeout_seconds, self.reactor)
length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
- msg = "Requested file is too large > %r bytes" % (max_size,)
+ msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 9c29e09653..6ed56099ca 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
cs_error,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
)
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.download_ratelimiter = Ratelimiter(
+ store=hs.get_storage_controllers().main,
+ clock=hs.get_clock(),
+ cfg=hs.config.ratelimiting.remote_media_downloads,
+ )
+
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
+ ip_address: str,
) -> None:
"""Respond to requests for remote media.
@@ -475,6 +483,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: the IP address of the requester
Returns:
Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: IP address of the requester
Returns:
The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self,
+ server_name: str,
+ media_id: str,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP.
+ ip_address: the IP address of the requester
Returns:
A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
- server_name, media_id, max_timeout_ms
+ server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
)
except SynapseError:
raise
@@ -630,6 +656,8 @@ class MediaRepository:
server_name: str,
media_id: str,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP
+ ip_address: the IP address of the requester
Returns:
The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index cc3acf51e1..f8a9560784 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -359,9 +359,10 @@ class ThumbnailProvider:
desired_method: str,
desired_type: str,
max_timeout_ms: int,
+ ip_address: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
+ server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
respond_404(request)
@@ -422,12 +423,13 @@ class ThumbnailProvider:
method: str,
m_type: str,
max_timeout_ms: int,
+ ip_address: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
+ server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
return
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index 172d240783..0c089163c1 100644
--- a/synapse/rest/client/media.py
+++ b/synapse/rest/client/media.py
@@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
respond_404(request)
return
+ ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
+ ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8ba723c8d4..1628d58926 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
respond_404(request)
return
+ ip_address = request.getClientAddress().host
await self.media_repo.get_remote_media(
- request, server_name, media_id, file_name, max_timeout_ms
+ request,
+ server_name,
+ media_id,
+ file_name,
+ max_timeout_ms,
+ ip_address,
)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index fe8fbb06e4..ce511c6dce 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
+ ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
+ ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
|