diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 13fff302fe..79b7631172 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -13,7 +13,7 @@
# limitations under the License.
from collections import OrderedDict
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
@@ -23,10 +23,8 @@ class Ratelimiter(object):
Ratelimit actions marked by arbitrary keys.
Args:
- rate_hz: The long term number of actions that can be performed in a
- second.
- burst_count: How many actions that can be performed before being
- limited.
+ rate_hz: The long term number of actions that can be performed in a second.
+ burst_count: How many actions that can be performed before being limited.
"""
def __init__(self, rate_hz: float, burst_count: int):
@@ -40,7 +38,12 @@ class Ratelimiter(object):
self.burst_count = burst_count
def can_do_action(
- self, key: Any, time_now_s: int, update: bool = True,
+ self,
+ key: Any,
+ time_now_s: int,
+ update: bool = True,
+ rate_hz: Optional[float] = None,
+ burst_count: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
@@ -49,27 +52,36 @@ class Ratelimiter(object):
(when sending events), an IP address, etc.
time_now_s: The time now
update: Whether to count this check as performing the action
+ rate_hz: The long term number of actions that can be performed in a second.
+ Overrides the value set during instantiation if set.
+ burst_count: How many actions that can be performed before being limited.
+ Overrides the value set during instantiation if set.
+
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The time in seconds of when it can next be performed.
-1 if a rate_hz has not been defined for this Ratelimiter
"""
+ # Override default values if set
+ rate_hz = rate_hz or self.rate_hz
+ burst_count = burst_count or self.burst_count
+
# Remove any expired entries
- self._prune_message_counts(time_now_s)
+ self._prune_message_counts(time_now_s, rate_hz)
# Check if there is an existing count entry for this key
action_count, time_start, = self.actions.get(key, (0.0, time_now_s))
# Check whether performing another action is allowed
time_delta = time_now_s - time_start
- performed_count = action_count - time_delta * self.rate_hz
+ performed_count = action_count - time_delta * rate_hz
if performed_count < 0:
# Allow, reset back to count 1
allowed = True
time_start = time_now_s
action_count = 1.0
- elif performed_count > self.burst_count - 1.0:
+ elif performed_count > burst_count - 1.0:
# Deny, we have exceeded our burst count
allowed = False
else:
@@ -82,9 +94,7 @@ class Ratelimiter(object):
# Figure out the time when an action can be performed again
if self.rate_hz > 0:
- time_allowed = (
- time_start + (action_count - self.burst_count + 1) / self.rate_hz
- )
+ time_allowed = time_start + (action_count - burst_count + 1) / rate_hz
# Don't give back a time in the past
if time_allowed < time_now_s:
@@ -95,26 +105,34 @@ class Ratelimiter(object):
return allowed, time_allowed
- def _prune_message_counts(self, time_now_s: int):
- """Remove message count entries that are older than
+ def _prune_message_counts(self, time_now_s: int, rate_hz: float):
+ """Remove message count entries that have not exceeded their defined
+ rate_hz limit
Args:
time_now_s: The current time
+ rate_hz: The long term number of actions that can be performed in a second.
"""
# We create a copy of the key list here as the dictionary is modified during
# the loop
for key in list(self.actions.keys()):
action_count, time_start = self.actions[key]
+ # Rate limit = "seconds since we started limiting this action" * rate_hz
+ # If this limit has not been exceeded, wipe our record of this action
time_delta = time_now_s - time_start
- if action_count - time_delta * self.rate_hz > 0:
- # XXX: Should this be a continue?
- break
+ if action_count - time_delta * rate_hz > 0:
+ continue
else:
del self.actions[key]
def ratelimit(
- self, key: Any, time_now_s: int, update: bool = True,
+ self,
+ key: Any,
+ time_now_s: int,
+ update: bool = True,
+ rate_hz: Optional[float] = None,
+ burst_count: Optional[int] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
@@ -122,12 +140,22 @@ class Ratelimiter(object):
key: An arbitrary key used to classify an action
time_now_s: The current time
update: Whether to count this check as performing the action
+ rate_hz: The long term number of actions that can be performed in a second.
+ Overrides the value set during instantiation if set.
+ burst_count: How many actions that can be performed before being limited.
+ Overrides the value set during instantiation if set.
Raises:
LimitExceededError: If an action could not be performed, along with the time in
milliseconds until the action can be performed again
"""
- allowed, time_allowed = self.can_do_action(key, time_now_s, update)
+ # Override default values if set
+ rate_hz = rate_hz or self.rate_hz
+ burst_count = burst_count or self.burst_count
+
+ allowed, time_allowed = self.can_do_action(
+ key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count
+ )
if not allowed:
raise LimitExceededError(
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 206702b6ad..e10e2427c4 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -46,7 +46,6 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs
- self.ratelimiter = None
self.request_ratelimiter = hs.get_request_ratelimiter()
self._rc_message = self.hs.config.rc_message
@@ -103,14 +102,17 @@ class BaseHandler(object):
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
- # ratelimiter as to not have user_id's clash
+ # ratelimiter as to not have user_ids clash
self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
else:
# Override rate and burst count per-user
- self.request_ratelimiter.rate_hz = messages_per_second
- self.request_ratelimiter.burst_count = burst_count
-
- self.request_ratelimiter.ratelimit(user_id, time_now, update)
+ self.request_ratelimiter.ratelimit(
+ user_id,
+ time_now,
+ update,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ )
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 5b2dcde2ba..891c986fbc 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -56,8 +56,7 @@ class ProfileTestCase(unittest.TestCase):
federation_server=Mock(),
federation_registry=self.mock_registry,
request_ratelimiter=NonCallableMock(
- # rate_hz and burst_count are overridden in BaseHandler
- spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+ spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 28b7ce085b..ba10f34468 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -50,14 +50,12 @@ class RoomBase(unittest.HomeserverTestCase):
http_client=None,
federation_client=Mock(),
request_ratelimiter=NonCallableMock(
- # rate_hz and burst_count are overridden in BaseHandler
- spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+ spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
)
self.request_ratelimiter = self.hs.get_request_ratelimiter()
self.request_ratelimiter.can_do_action.return_value = (True, 0)
- self.request_ratelimiter.rate_hz = Mock()
self.login_ratelimiter = self.hs.get_login_ratelimiter()
self.login_ratelimiter.can_do_action.return_value = (True, 0)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 27d38d354a..2ec678a2a2 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -43,8 +43,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
http_client=None,
federation_client=Mock(),
request_ratelimiter=NonCallableMock(
- # rate_hz and burst_count are overridden in BaseHandler
- spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+ spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
)
|