diff options
author | David Robertson <davidr@element.io> | 2022-06-28 15:50:18 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-07-04 19:10:14 +0100 |
commit | c2e3025b3385d8770b22a5350db8e5b23011818f (patch) | |
tree | 3e5f6a31b7b9a76878533bdc5df7aa39e18c1aab | |
parent | Rate limiter: describe leaky bucket (diff) | |
download | synapse-c2e3025b3385d8770b22a5350db8e5b23011818f.tar.xz |
Rate limiter: Pull out some small methods
-rw-r--r-- | synapse/api/ratelimiting.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 965032e3af..ef6f2377cf 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -75,6 +75,29 @@ class Ratelimiter: # * The rate_hz (leak rate) of this particular bucket. self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() + def _get_key( + self, requester: Optional[Requester], key: Optional[Hashable] + ) -> Hashable: + """Use the requester's MXID as a fallback key if no key is provided. + + Pulled out so that `can_do_action` and `record_action` are consistent. + """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + return key + + def _get_action_counts( + self, key: Hashable, time_now_s: float + ) -> Tuple[float, float, float]: + """Retrieve the action counts, with a fallback representing an empty bucket. + + Pulled out so that `can_do_action` and `record_action` are consistent. + """ + return self.actions.get(key, (0.0, time_now_s, 0.0)) + async def can_do_action( self, requester: Optional[Requester], @@ -114,11 +137,7 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ - if key is None: - if not requester: - raise ValueError("Must supply at least one of `requester` or `key`") - - key = requester.user.to_string() + key = self._get_key(requester, key) if requester: # Disable rate limiting of users belonging to any AS that is configured @@ -147,7 +166,7 @@ class Ratelimiter: self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + action_count, time_start, _ = self._get_action_counts(key, time_now_s) # Check whether performing another action is allowed time_delta = time_now_s - time_start |