diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 8665b3b765..38e5bdaa75 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -20,7 +20,7 @@
#
#
-from typing import Dict, Hashable, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RatelimitSettings
@@ -28,6 +28,12 @@ from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
+if TYPE_CHECKING:
+ # To avoid circular imports:
+ from synapse.module_api.callbacks.ratelimit_callbacks import (
+ RatelimitModuleApiCallbacks,
+ )
+
class Ratelimiter:
"""
@@ -72,12 +78,14 @@ class Ratelimiter:
store: DataStore,
clock: Clock,
cfg: RatelimitSettings,
+ ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None,
):
self.clock = clock
self.rate_hz = cfg.per_second
self.burst_count = cfg.burst_count
self.store = store
self._limiter_name = cfg.key
+ self._ratelimit_callbacks = ratelimit_callbacks
# A dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
@@ -165,6 +173,20 @@ class Ratelimiter:
if override and not override.messages_per_second:
return True, -1.0
+ if requester and self._ratelimit_callbacks:
+ # Check if the user has a custom rate limit for this specific limiter
+ # as returned by the module API.
+ module_override = (
+ await self._ratelimit_callbacks.get_ratelimit_override_for_user(
+ requester.user.to_string(),
+ self._limiter_name,
+ )
+ )
+
+ if module_override:
+ rate_hz = module_override.messages_per_second
+ burst_count = module_override.burst_count
+
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
|