From 9b2bc75ed49ea52008dc787a09337ca4a843d8e7 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 4 Jun 2025 13:09:11 +0100 Subject: Add ratelimit callbacks to module API to allow dynamic ratelimiting (#18458) --- synapse/api/ratelimiting.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) (limited to 'synapse/api/ratelimiting.py') diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 8665b3b765..38e5bdaa75 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -20,7 +20,7 @@ # # -from typing import Dict, Hashable, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import RatelimitSettings @@ -28,6 +28,12 @@ from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock +if TYPE_CHECKING: + # To avoid circular imports: + from synapse.module_api.callbacks.ratelimit_callbacks import ( + RatelimitModuleApiCallbacks, + ) + class Ratelimiter: """ @@ -72,12 +78,14 @@ class Ratelimiter: store: DataStore, clock: Clock, cfg: RatelimitSettings, + ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None, ): self.clock = clock self.rate_hz = cfg.per_second self.burst_count = cfg.burst_count self.store = store self._limiter_name = cfg.key + self._ratelimit_callbacks = ratelimit_callbacks # A dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -165,6 +173,20 @@ class Ratelimiter: if override and not override.messages_per_second: return True, -1.0 + if requester and self._ratelimit_callbacks: + # Check if the user has a custom rate limit for this specific limiter + # as returned by the module API. + module_override = ( + await self._ratelimit_callbacks.get_ratelimit_override_for_user( + requester.user.to_string(), + self._limiter_name, + ) + ) + + if module_override: + rate_hz = module_override.messages_per_second + burst_count = module_override.burst_count + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz -- cgit 1.5.1