summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-05-29 18:30:44 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-05-29 18:35:19 +0100
commitf6203a60e099651b51dc3d755dbd6b1c6aa8ce08 (patch)
tree19e077ab17d1fd79ee11d30097bbb68cb846e28c
parentchangelog (diff)
downloadsynapse-f6203a60e099651b51dc3d755dbd6b1c6aa8ce08.tar.xz
Make rate_hz and burst_count overridable per-request
-rw-r--r--synapse/api/ratelimiting.py66
-rw-r--r--synapse/handlers/_base.py14
-rw-r--r--tests/handlers/test_profile.py3
-rw-r--r--tests/rest/client/v1/test_rooms.py4
-rw-r--r--tests/rest/client/v1/test_typing.py3
5 files changed, 58 insertions, 32 deletions
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 13fff302fe..79b7631172 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import OrderedDict
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
 
@@ -23,10 +23,8 @@ class Ratelimiter(object):
     Ratelimit actions marked by arbitrary keys.
 
     Args:
-        rate_hz: The long term number of actions that can be performed in a
-            second.
-        burst_count: How many actions that can be performed before being
-            limited.
+        rate_hz: The long term number of actions that can be performed in a second.
+        burst_count: How many actions that can be performed before being limited.
     """
 
     def __init__(self, rate_hz: float, burst_count: int):
@@ -40,7 +38,12 @@ class Ratelimiter(object):
         self.burst_count = burst_count
 
     def can_do_action(
-        self, key: Any, time_now_s: int, update: bool = True,
+        self,
+        key: Any,
+        time_now_s: int,
+        update: bool = True,
+        rate_hz: Optional[float] = None,
+        burst_count: Optional[int] = None,
     ) -> Tuple[bool, float]:
         """Can the entity (e.g. user or IP address) perform the action?
 
@@ -49,27 +52,36 @@ class Ratelimiter(object):
                 (when sending events), an IP address, etc.
             time_now_s: The time now
             update: Whether to count this check as performing the action
+            rate_hz: The long term number of actions that can be performed in a second.
+                Overrides the value set during instantiation if set.
+            burst_count: How many actions that can be performed before being limited.
+                Overrides the value set during instantiation if set.
+
         Returns:
             A tuple containing:
                 * A bool indicating if they can perform the action now
                 * The time in seconds of when it can next be performed.
                   -1 if a rate_hz has not been defined for this Ratelimiter
         """
+        # Override default values if set
+        rate_hz = rate_hz or self.rate_hz
+        burst_count = burst_count or self.burst_count
+
         # Remove any expired entries
-        self._prune_message_counts(time_now_s)
+        self._prune_message_counts(time_now_s, rate_hz)
 
         # Check if there is an existing count entry for this key
         action_count, time_start, = self.actions.get(key, (0.0, time_now_s))
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
-        performed_count = action_count - time_delta * self.rate_hz
+        performed_count = action_count - time_delta * rate_hz
         if performed_count < 0:
             # Allow, reset back to count 1
             allowed = True
             time_start = time_now_s
             action_count = 1.0
-        elif performed_count > self.burst_count - 1.0:
+        elif performed_count > burst_count - 1.0:
             # Deny, we have exceeded our burst count
             allowed = False
         else:
@@ -82,9 +94,7 @@ class Ratelimiter(object):
 
         # Figure out the time when an action can be performed again
         if self.rate_hz > 0:
-            time_allowed = (
-                time_start + (action_count - self.burst_count + 1) / self.rate_hz
-            )
+            time_allowed = time_start + (action_count - burst_count + 1) / rate_hz
 
             # Don't give back a time in the past
             if time_allowed < time_now_s:
@@ -95,26 +105,34 @@ class Ratelimiter(object):
 
         return allowed, time_allowed
 
-    def _prune_message_counts(self, time_now_s: int):
-        """Remove message count entries that are older than
+    def _prune_message_counts(self, time_now_s: int, rate_hz: float):
+        """Remove message count entries that have not exceeded their defined
+        rate_hz limit
 
         Args:
             time_now_s: The current time
+            rate_hz: The long term number of actions that can be performed in a second.
         """
         # We create a copy of the key list here as the dictionary is modified during
         # the loop
         for key in list(self.actions.keys()):
             action_count, time_start = self.actions[key]
 
+            # Rate limit = "seconds since we started limiting this action" * rate_hz
+            # If this limit has not been exceeded, wipe our record of this action
             time_delta = time_now_s - time_start
-            if action_count - time_delta * self.rate_hz > 0:
-                # XXX: Should this be a continue?
-                break
+            if action_count - time_delta * rate_hz > 0:
+                continue
             else:
                 del self.actions[key]
 
     def ratelimit(
-        self, key: Any, time_now_s: int, update: bool = True,
+        self,
+        key: Any,
+        time_now_s: int,
+        update: bool = True,
+        rate_hz: Optional[float] = None,
+        burst_count: Optional[int] = None,
     ):
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
@@ -122,12 +140,22 @@ class Ratelimiter(object):
             key: An arbitrary key used to classify an action
             time_now_s: The current time
             update: Whether to count this check as performing the action
+            rate_hz: The long term number of actions that can be performed in a second.
+                Overrides the value set during instantiation if set.
+            burst_count: How many actions that can be performed before being limited.
+                Overrides the value set during instantiation if set.
 
         Raises:
             LimitExceededError: If an action could not be performed, along with the time in
                 milliseconds until the action can be performed again
         """
-        allowed, time_allowed = self.can_do_action(key, time_now_s, update)
+        # Override default values if set
+        rate_hz = rate_hz or self.rate_hz
+        burst_count = burst_count or self.burst_count
+
+        allowed, time_allowed = self.can_do_action(
+            key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count
+        )
 
         if not allowed:
             raise LimitExceededError(
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 206702b6ad..e10e2427c4 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -46,7 +46,6 @@ class BaseHandler(object):
         self.clock = hs.get_clock()
         self.hs = hs
 
-        self.ratelimiter = None
         self.request_ratelimiter = hs.get_request_ratelimiter()
         self._rc_message = self.hs.config.rc_message
 
@@ -103,14 +102,17 @@ class BaseHandler(object):
 
         if is_admin_redaction and self.admin_redaction_ratelimiter:
             # If we have separate config for admin redactions, use a separate
-            # ratelimiter as to not have user_id's clash
+            # ratelimiter as to not have user_ids clash
             self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
         else:
             # Override rate and burst count per-user
-            self.request_ratelimiter.rate_hz = messages_per_second
-            self.request_ratelimiter.burst_count = burst_count
-
-            self.request_ratelimiter.ratelimit(user_id, time_now, update)
+            self.request_ratelimiter.ratelimit(
+                user_id,
+                time_now,
+                update,
+                rate_hz=messages_per_second,
+                burst_count=burst_count,
+            )
 
     async def maybe_kick_guest_users(self, event, context=None):
         # Technically this function invalidates current_state by changing it.
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 5b2dcde2ba..891c986fbc 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -56,8 +56,7 @@ class ProfileTestCase(unittest.TestCase):
             federation_server=Mock(),
             federation_registry=self.mock_registry,
             request_ratelimiter=NonCallableMock(
-                # rate_hz and burst_count are overridden in BaseHandler
-                spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+                spec_set=["can_do_action", "ratelimit"]
             ),
             login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
         )
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 28b7ce085b..ba10f34468 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -50,14 +50,12 @@ class RoomBase(unittest.HomeserverTestCase):
             http_client=None,
             federation_client=Mock(),
             request_ratelimiter=NonCallableMock(
-                # rate_hz and burst_count are overridden in BaseHandler
-                spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+                spec_set=["can_do_action", "ratelimit"]
             ),
             login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
         )
         self.request_ratelimiter = self.hs.get_request_ratelimiter()
         self.request_ratelimiter.can_do_action.return_value = (True, 0)
-        self.request_ratelimiter.rate_hz = Mock()
 
         self.login_ratelimiter = self.hs.get_login_ratelimiter()
         self.login_ratelimiter.can_do_action.return_value = (True, 0)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 27d38d354a..2ec678a2a2 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -43,8 +43,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             http_client=None,
             federation_client=Mock(),
             request_ratelimiter=NonCallableMock(
-                # rate_hz and burst_count are overridden in BaseHandler
-                spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
+                spec_set=["can_do_action", "ratelimit"]
             ),
             login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
         )