diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2020-05-29 18:30:44 +0100 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2020-05-29 18:35:19 +0100 |
commit | f6203a60e099651b51dc3d755dbd6b1c6aa8ce08 (patch) | |
tree | 19e077ab17d1fd79ee11d30097bbb68cb846e28c | |
parent | changelog (diff) | |
download | synapse-f6203a60e099651b51dc3d755dbd6b1c6aa8ce08.tar.xz |
Make rate_hz and burst_count overridable per-request
-rw-r--r-- | synapse/api/ratelimiting.py | 66 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 14 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 3 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 4 | ||||
-rw-r--r-- | tests/rest/client/v1/test_typing.py | 3 |
5 files changed, 58 insertions, 32 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 13fff302fe..79b7631172 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Tuple +from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError @@ -23,10 +23,8 @@ class Ratelimiter(object): Ratelimit actions marked by arbitrary keys. Args: - rate_hz: The long term number of actions that can be performed in a - second. - burst_count: How many actions that can be performed before being - limited. + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ def __init__(self, rate_hz: float, burst_count: int): @@ -40,7 +38,12 @@ class Ratelimiter(object): self.burst_count = burst_count def can_do_action( - self, key: Any, time_now_s: int, update: bool = True, + self, + key: Any, + time_now_s: int, + update: bool = True, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -49,27 +52,36 @@ class Ratelimiter(object): (when sending events), an IP address, etc. time_now_s: The time now update: Whether to count this check as performing the action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + Returns: A tuple containing: * A bool indicating if they can perform the action now * The time in seconds of when it can next be performed. -1 if a rate_hz has not been defined for this Ratelimiter """ + # Override default values if set + rate_hz = rate_hz or self.rate_hz + burst_count = burst_count or self.burst_count + # Remove any expired entries - self._prune_message_counts(time_now_s) + self._prune_message_counts(time_now_s, rate_hz) # Check if there is an existing count entry for this key action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) # Check whether performing another action is allowed time_delta = time_now_s - time_start - performed_count = action_count - time_delta * self.rate_hz + performed_count = action_count - time_delta * rate_hz if performed_count < 0: # Allow, reset back to count 1 allowed = True time_start = time_now_s action_count = 1.0 - elif performed_count > self.burst_count - 1.0: + elif performed_count > burst_count - 1.0: # Deny, we have exceeded our burst count allowed = False else: @@ -82,9 +94,7 @@ class Ratelimiter(object): # Figure out the time when an action can be performed again if self.rate_hz > 0: - time_allowed = ( - time_start + (action_count - self.burst_count + 1) / self.rate_hz - ) + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz # Don't give back a time in the past if time_allowed < time_now_s: @@ -95,26 +105,34 @@ class Ratelimiter(object): return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int): - """Remove message count entries that are older than + def _prune_message_counts(self, time_now_s: int, rate_hz: float): + """Remove message count entries that have not exceeded their defined + rate_hz limit Args: time_now_s: The current time + rate_hz: The long term number of actions that can be performed in a second. """ # We create a copy of the key list here as the dictionary is modified during # the loop for key in list(self.actions.keys()): action_count, time_start = self.actions[key] + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if action_count - time_delta * self.rate_hz > 0: - # XXX: Should this be a continue? - break + if action_count - time_delta * rate_hz > 0: + continue else: del self.actions[key] def ratelimit( - self, key: Any, time_now_s: int, update: bool = True, + self, + key: Any, + time_now_s: int, + update: bool = True, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, ): """Checks if an action can be performed. If not, raises a LimitExceededError @@ -122,12 +140,22 @@ class Ratelimiter(object): key: An arbitrary key used to classify an action time_now_s: The current time update: Whether to count this check as performing the action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. Raises: LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ - allowed, time_allowed = self.can_do_action(key, time_now_s, update) + # Override default values if set + rate_hz = rate_hz or self.rate_hz + burst_count = burst_count or self.burst_count + + allowed, time_allowed = self.can_do_action( + key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count + ) if not allowed: raise LimitExceededError( diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 206702b6ad..e10e2427c4 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -46,7 +46,6 @@ class BaseHandler(object): self.clock = hs.get_clock() self.hs = hs - self.ratelimiter = None self.request_ratelimiter = hs.get_request_ratelimiter() self._rc_message = self.hs.config.rc_message @@ -103,14 +102,17 @@ class BaseHandler(object): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate - # ratelimiter as to not have user_id's clash + # ratelimiter as to not have user_ids clash self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update) else: # Override rate and burst count per-user - self.request_ratelimiter.rate_hz = messages_per_second - self.request_ratelimiter.burst_count = burst_count - - self.request_ratelimiter.ratelimit(user_id, time_now, update) + self.request_ratelimiter.ratelimit( + user_id, + time_now, + update, + rate_hz=messages_per_second, + burst_count=burst_count, + ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5b2dcde2ba..891c986fbc 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -56,8 +56,7 @@ class ProfileTestCase(unittest.TestCase): federation_server=Mock(), federation_registry=self.mock_registry, request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 28b7ce085b..ba10f34468 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,14 +50,12 @@ class RoomBase(unittest.HomeserverTestCase): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) self.request_ratelimiter = self.hs.get_request_ratelimiter() self.request_ratelimiter.can_do_action.return_value = (True, 0) - self.request_ratelimiter.rate_hz = Mock() self.login_ratelimiter = self.hs.get_login_ratelimiter() self.login_ratelimiter.can_do_action.return_value = (True, 0) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 27d38d354a..2ec678a2a2 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -43,8 +43,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) |