summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13253.misc1
-rw-r--r--synapse/api/ratelimiting.py94
-rw-r--r--tests/api/test_ratelimiting.py74
3 files changed, 157 insertions, 12 deletions
diff --git a/changelog.d/13253.misc b/changelog.d/13253.misc
new file mode 100644
index 0000000000..cba6b9ee0f
--- /dev/null
+++ b/changelog.d/13253.misc
@@ -0,0 +1 @@
+Preparatory work for a per-room rate limiter on joins.
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 54d13026c9..f43965c1c8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+    - the key that this request falls under (which bucket to inspect), and
+    - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and `cost` tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+    - the time point when the bucket was last completely empty, and
+    - how many tokens have added to the bucket permitted since then.
+    Then for each incoming request, we can calculate how many tokens have leaked
+    since this time point, and use that to decide if we should accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that an action(s) took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user wants to do this action. If the user
+                cannot do all of the actions, the user's action count is not incremented
+                at all.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 18649c2c05..c86f783c5b 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
 
         # Check that we get rate limited after using that token.
         self.assertFalse(consume_at(11.1))
+
+    def test_record_action_which_doesnt_fill_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe two actions, leaving room in the bucket for one more.
+        limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+        # We should be able to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_fills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe three actions, filling up the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
+
+        # We should be unable to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be able to take one action...
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_overfills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe four actions, exceeding the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
+
+        # We should be prevented from taking a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be unable to take an action
+        # because the bucket is still full.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+        # But after another 10 seconds we leak a second token, giving us room for
+        # action.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
+        )
+        self.assertTrue(success)