summary refs log tree commit diff
path: root/tests/api/test_ratelimiting.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-03-30 12:06:09 +0100
committerGitHub <noreply@github.com>2021-03-30 12:06:09 +0100
commit963f4309fe29206f3ba92b493e922280feea30ed (patch)
tree67250b7423dc4a0a1b47626efa55d69ea032f51d /tests/api/test_ratelimiting.py
parentUpdate changelog (diff)
downloadsynapse-963f4309fe29206f3ba92b493e922280feea30ed.tar.xz
Make RateLimiter class check for ratelimit overrides (#9711)
This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited.

We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits.

Fixes #9663
Diffstat (limited to 'tests/api/test_ratelimiting.py')
-rw-r--r--tests/api/test_ratelimiting.py168
1 files changed, 108 insertions, 60 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
 from tests import unittest
 
 
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
     def test_allowed_via_can_do_action(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
-        self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
-        self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
-        self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
-
-    def test_allowed_user_via_can_requester_do_action(self):
-        user_requester = create_requester("@user:example.com")
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
     def test_allowed_via_ratelimit(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=0)
+        self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
 
         # Should raise
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key="test_id", _time_now_s=5)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key="test_id", _time_now_s=5)
+            )
         self.assertEqual(context.exception.retry_after_ms, 5000)
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key="test_id", _time_now_s=10)
+        )
 
     def test_allowed_via_can_do_action_and_overriding_parameters(self):
         """Test that we can override options of can_do_action that would otherwise fail
         an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=0,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=0,
+            )
         )
         self.assertTrue(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # Second attempt, 1s later, will fail
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=1,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=1,
+            )
         )
         self.assertFalse(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, rate_hz=10.0
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.1, time_allowed)
 
         # Similarly if we allow a burst of 10 actions
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, burst_count=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
         fail an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        limiter.ratelimit(key=("test_id",), _time_now_s=0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+        )
 
         # Second attempt, 1s later, will fail
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key=("test_id",), _time_now_s=1)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+            )
         self.assertEqual(context.exception.retry_after_ms, 9000)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        )
 
         # Similarly if we allow a burst of 10 actions
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+        )
 
     def test_pruning(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        limiter.can_do_action(key="test_id_1", _time_now_s=0)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+        )
 
         self.assertIn("test_id_1", limiter.actions)
 
-        limiter.can_do_action(key="test_id_2", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+        )
 
         self.assertNotIn("test_id_1", limiter.actions)
+
+    def test_db_user_override(self):
+        """Test that users that have ratelimiting disabled in the DB aren't
+        ratelimited.
+        """
+        store = self.hs.get_datastore()
+
+        user_id = "@user:test"
+        requester = create_requester(user_id)
+
+        self.get_success(
+            store.db_pool.simple_insert(
+                table="ratelimit_override",
+                values={
+                    "user_id": user_id,
+                    "messages_per_second": None,
+                    "burst_count": None,
+                },
+                desc="test_db_user_override",
+            )
+        )
+
+        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+        # Shouldn't raise
+        for _ in range(20):
+            self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))