summary refs log tree commit diff
path: root/synapse/api/ratelimiting.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api/ratelimiting.py')
-rw-r--r--synapse/api/ratelimiting.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py

index 8665b3b765..38e5bdaa75 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py
@@ -20,7 +20,7 @@ # # -from typing import Dict, Hashable, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import RatelimitSettings @@ -28,6 +28,12 @@ from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock +if TYPE_CHECKING: + # To avoid circular imports: + from synapse.module_api.callbacks.ratelimit_callbacks import ( + RatelimitModuleApiCallbacks, + ) + class Ratelimiter: """ @@ -72,12 +78,14 @@ class Ratelimiter: store: DataStore, clock: Clock, cfg: RatelimitSettings, + ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None, ): self.clock = clock self.rate_hz = cfg.per_second self.burst_count = cfg.burst_count self.store = store self._limiter_name = cfg.key + self._ratelimit_callbacks = ratelimit_callbacks # A dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -165,6 +173,20 @@ class Ratelimiter: if override and not override.messages_per_second: return True, -1.0 + if requester and self._ratelimit_callbacks: + # Check if the user has a custom rate limit for this specific limiter + # as returned by the module API. + module_override = ( + await self._ratelimit_callbacks.get_ratelimit_override_for_user( + requester.user.to_string(), + self._limiter_name, + ) + ) + + if module_override: + rate_hz = module_override.messages_per_second + burst_count = module_override.burst_count + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz