diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 965032e3af..ef6f2377cf 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -75,6 +75,29 @@ class Ratelimiter:
# * The rate_hz (leak rate) of this particular bucket.
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
+ def _get_key(
+ self, requester: Optional[Requester], key: Optional[Hashable]
+ ) -> Hashable:
+ """Use the requester's MXID as a fallback key if no key is provided.
+
+ Pulled out so that `can_do_action` and `record_action` are consistent.
+ """
+ if key is None:
+ if not requester:
+ raise ValueError("Must supply at least one of `requester` or `key`")
+
+ key = requester.user.to_string()
+ return key
+
+ def _get_action_counts(
+ self, key: Hashable, time_now_s: float
+ ) -> Tuple[float, float, float]:
+ """Retrieve the action counts, with a fallback representing an empty bucket.
+
+ Pulled out so that `can_do_action` and `record_action` are consistent.
+ """
+ return self.actions.get(key, (0.0, time_now_s, 0.0))
+
async def can_do_action(
self,
requester: Optional[Requester],
@@ -114,11 +137,7 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
- if key is None:
- if not requester:
- raise ValueError("Must supply at least one of `requester` or `key`")
-
- key = requester.user.to_string()
+ key = self._get_key(requester, key)
if requester:
# Disable rate limiting of users belonging to any AS that is configured
@@ -147,7 +166,7 @@ class Ratelimiter:
self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key
- action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+ action_count, time_start, _ = self._get_action_counts(key, time_now_s)
# Check whether performing another action is allowed
time_delta = time_now_s - time_start
|