summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-06-28 15:50:18 +0100
committerDavid Robertson <davidr@element.io>2022-07-04 19:10:14 +0100
commitc2e3025b3385d8770b22a5350db8e5b23011818f (patch)
tree3e5f6a31b7b9a76878533bdc5df7aa39e18c1aab
parentRate limiter: describe leaky bucket (diff)
downloadsynapse-c2e3025b3385d8770b22a5350db8e5b23011818f.tar.xz
Rate limiter: Pull out some small methods
-rw-r--r--synapse/api/ratelimiting.py31
1 files changed, 25 insertions, 6 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 965032e3af..ef6f2377cf 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -75,6 +75,29 @@ class Ratelimiter:
         #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided.
+
+        Pulled out so that `can_do_action` and `record_action` are consistent.
+        """
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket.
+
+        Pulled out so that `can_do_action` and `record_action` are consistent.
+        """
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -114,11 +137,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -147,7 +166,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start