summary refs log tree commit diff
path: root/tests/api/test_ratelimiting.py
diff options
context:
space:
mode:
authorWill Hunt <will@half-shot.uk>2020-08-21 15:07:56 +0100
committerBrendan Abolivier <babolivier@matrix.org>2020-08-24 14:53:53 +0100
commit2df82ae451e03d76fae5381961dd6229d5796400 (patch)
treed709012cb871a80bf45b15a3bd2a5146feace59b /tests/api/test_ratelimiting.py
parentChangelog changes (diff)
downloadsynapse-2df82ae451e03d76fae5381961dd6229d5796400.tar.xz
Do not apply ratelimiting on joins to appservices (#8139)
Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited.

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Co-authored-by: Erik Johnston <erik@matrix.org>
Diffstat (limited to '')
-rw-r--r--tests/api/test_ratelimiting.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index d580e729c5..1e1f30d790 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@
 from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.types import create_requester
 
 from tests import unittest
 
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
 
+    def test_allowed_user_via_can_requester_do_action(self):
+        user_requester = create_requester("@user:example.com")
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=5
+        )
+        self.assertFalse(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(20.0, time_allowed)
+
+    def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+        appservice = ApplicationService(
+            None, "example.com", id="foo", rate_limited=True,
+        )
+        as_requester = create_requester("@user:example.com", app_service=appservice)
+
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=5
+        )
+        self.assertFalse(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(20.0, time_allowed)
+
+    def test_allowed_appservice_via_can_requester_do_action(self):
+        appservice = ApplicationService(
+            None, "example.com", id="foo", rate_limited=False,
+        )
+        as_requester = create_requester("@user:example.com", app_service=appservice)
+
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=5
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
     def test_allowed_via_ratelimit(self):
         limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)