diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index b80630c5d3..4f3bf8f770 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -20,8 +20,7 @@
#
#
-from collections import OrderedDict
-from typing import Hashable, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RatelimitSettings
@@ -29,6 +28,12 @@ from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
+if TYPE_CHECKING:
+ # To avoid circular imports:
+ from synapse.module_api.callbacks.ratelimit_callbacks import (
+ RatelimitModuleApiCallbacks,
+ )
+
class Ratelimiter:
"""
@@ -73,19 +78,23 @@ class Ratelimiter:
store: DataStore,
clock: Clock,
cfg: RatelimitSettings,
+ ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None,
):
self.clock = clock
self.rate_hz = cfg.per_second
self.burst_count = cfg.burst_count
self.store = store
self._limiter_name = cfg.key
+ self._ratelimit_callbacks = ratelimit_callbacks
- # An ordered dictionary representing the token buckets tracked by this rate
+ # A dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
# * The number of tokens currently in the bucket,
# * The time point when the bucket was last completely empty, and
# * The rate_hz (leak rate) of this particular bucket.
- self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
+ self.actions: Dict[Hashable, Tuple[float, float, float]] = {}
+
+ self.clock.looping_call(self._prune_message_counts, 60 * 1000)
def _get_key(
self, requester: Optional[Requester], key: Optional[Hashable]
@@ -164,14 +173,25 @@ class Ratelimiter:
if override and not override.messages_per_second:
return True, -1.0
+ if requester and self._ratelimit_callbacks:
+ # Check if the user has a custom rate limit for this specific limiter
+ # as returned by the module API.
+ module_override = (
+ await self._ratelimit_callbacks.get_ratelimit_override_for_user(
+ requester.user.to_string(),
+ self._limiter_name,
+ )
+ )
+
+ if module_override:
+ rate_hz = module_override.per_second
+ burst_count = module_override.burst_count
+
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
burst_count = burst_count if burst_count is not None else self.burst_count
- # Remove any expired entries
- self._prune_message_counts(time_now_s)
-
# Check if there is an existing count entry for this key
action_count, time_start, _ = self._get_action_counts(key, time_now_s)
@@ -246,13 +266,12 @@ class Ratelimiter:
action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
self.actions[key] = (action_count + n_actions, time_start, rate_hz)
- def _prune_message_counts(self, time_now_s: float) -> None:
+ def _prune_message_counts(self) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
-
- Args:
- time_now_s: The current time
"""
+ time_now_s = self.clock.time()
+
# We create a copy of the key list here as the dictionary is modified during
# the loop
for key in list(self.actions.keys()):
@@ -275,6 +294,7 @@ class Ratelimiter:
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
+ pause: Optional[float] = 0.5,
) -> None:
"""Checks if an action can be performed. If not, raises a LimitExceededError
@@ -298,6 +318,8 @@ class Ratelimiter:
at all.
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
+ pause: Time in seconds to pause when an action is being limited. Defaults to 0.5
+ to stop clients from "tight-looping" on retrying their request.
Raises:
LimitExceededError: If an action could not be performed, along with the time in
@@ -316,9 +338,8 @@ class Ratelimiter:
)
if not allowed:
- # We pause for a bit here to stop clients from "tight-looping" on
- # retrying their request.
- await self.clock.sleep(0.5)
+ if pause:
+ await self.clock.sleep(pause)
raise LimitExceededError(
limiter_name=self._limiter_name,
|