diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited.
"""
- def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+ def __init__(
+ self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+ ):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
+ self.store = store
# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
- def can_requester_do_action(
- self,
- requester: Requester,
- rate_hz: Optional[float] = None,
- burst_count: Optional[int] = None,
- update: bool = True,
- _time_now_s: Optional[int] = None,
- ) -> Tuple[bool, float]:
- """Can the requester perform the action?
-
- Args:
- requester: The requester to key off when rate limiting. The user property
- will be used.
- rate_hz: The long term number of actions that can be performed in a second.
- Overrides the value set during instantiation if set.
- burst_count: How many actions that can be performed before being limited.
- Overrides the value set during instantiation if set.
- update: Whether to count this check as performing the action
- _time_now_s: The current time. Optional, defaults to the current time according
- to self.clock. Only used by tests.
-
- Returns:
- A tuple containing:
- * A bool indicating if they can perform the action now
- * The reactor timestamp for when the action can be performed next.
- -1 if rate_hz is less than or equal to zero
- """
- # Disable rate limiting of users belonging to any AS that is configured
- # not to be rate limited in its registration file (rate_limited: true|false).
- if requester.app_service and not requester.app_service.is_rate_limited():
- return True, -1.0
-
- return self.can_do_action(
- requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
- )
-
- def can_do_action(
+ async def can_do_action(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event sending
+ ratelimiter, rather than all ratelimiters)
+
Args:
- key: The key we should use when rate limiting. Can be a user ID
- (when sending events), an IP address, etc.
+ requester: The requester that is doing the action, if any. Used to check
+ if the user has ratelimits disabled in the database.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
+ if key is None:
+ if not requester:
+ raise ValueError("Must supply at least one of `requester` or `key`")
+
+ key = requester.user.to_string()
+
+ if requester:
+ # Disable rate limiting of users belonging to any AS that is configured
+ # not to be rate limited in its registration file (rate_limited: true|false).
+ if requester.app_service and not requester.app_service.is_rate_limited():
+ return True, -1.0
+
+ # Check if ratelimiting has been disabled for the user.
+ #
+ # Note that we don't use the returned rate/burst count, as the table
+ # is specifically for the event sending ratelimiter. Instead, we
+ # only use it to (somewhat cheekily) infer whether the user should
+ # be subject to any rate limiting or not.
+ override = await self.store.get_ratelimit_for_user(
+ requester.authenticated_entity
+ )
+ if override and not override.messages_per_second:
+ return True, -1.0
+
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
else:
del self.actions[key]
- def ratelimit(
+ async def ratelimit(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event sending
+ ratelimiter, rather than all ratelimiters)
+
Args:
- key: An arbitrary key used to classify an action
+ requester: The requester that is doing the action, if any. Used to check for
+ if the user has ratelimits disabled.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
- allowed, time_allowed = self.can_do_action(
+ allowed, time_allowed = await self.can_do_action(
+ requester,
key,
rate_hz=rate_hz,
burst_count=burst_count,
|