summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_errors.py15
-rw-r--r--tests/api/test_ratelimiting.py67
-rw-r--r--tests/config/test_ratelimiting.py31
3 files changed, 84 insertions, 29 deletions
diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
index 319abfe63d..8e159029d9 100644
--- a/tests/api/test_errors.py
+++ b/tests/api/test_errors.py
@@ -1,6 +1,5 @@
 # Copyright 2023 The Matrix.org Foundation C.I.C.
 #
-#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -13,24 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+
 from synapse.api.errors import LimitExceededError
 
 from tests import unittest
 
 
-class ErrorsTestCase(unittest.TestCase):
+class LimitExceededErrorTestCase(unittest.TestCase):
+    def test_key_appears_in_context_but_not_error_dict(self) -> None:
+        err = LimitExceededError("needle")
+        serialised = json.dumps(err.error_dict(None))
+        self.assertIn("needle", err.debug_context)
+        self.assertNotIn("needle", serialised)
+
     # Create a sub-class to avoid mutating the class-level property.
     class LimitExceededErrorHeaders(LimitExceededError):
         include_retry_after_header = True
 
     def test_limit_exceeded_header(self) -> None:
-        err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100)
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
         self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
         assert err.headers is not None
         self.assertEqual(err.headers.get("Retry-After"), "1")
 
     def test_limit_exceeded_rounding(self) -> None:
-        err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001)
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
         self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
         assert err.headers is not None
         self.assertEqual(err.headers.get("Retry-After"), "4")
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fa6c1c02ce..a24638c9ef 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,5 +1,6 @@
 from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
 from synapse.appservice import ApplicationService
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.types import create_requester
 
 from tests import unittest
@@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # Shouldn't raise
@@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             )
         )
 
-        limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=store,
+            clock=self.clock,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
+        )
 
         # Shouldn't raise
         for _ in range(20):
@@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
         # Test that 4 actions aren't allowed with a maximum burst of 3.
         allowed, time_allowed = self.get_success_or_raise(
@@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
         )
 
         def consume_at(time: float) -> bool:
@@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe two actions, leaving room in the bucket for one more.
@@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe three actions, filling up the bucket.
@@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe four actions, exceeding the bucket.
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index f12147eaa0..0c27dd21e2 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -12,11 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import RatelimitSettings
 
 from tests.unittest import TestCase
 from tests.utils import default_config
 
 
+class ParseRatelimitSettingsTestcase(TestCase):
+    def test_depth_1(self) -> None:
+        cfg = {
+            "a": {
+                "per_second": 5,
+                "burst_count": 10,
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a")
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+    def test_depth_2(self) -> None:
+        cfg = {
+            "a": {
+                "b": {
+                    "per_second": 5,
+                    "burst_count": 10,
+                },
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a.b")
+        self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
+
+    def test_missing(self) -> None:
+        parsed = RatelimitSettings.parse(
+            {}, "a", defaults={"per_second": 5, "burst_count": 10}
+        )
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+
 class RatelimitConfigTestCase(TestCase):
     def test_parse_rc_federation(self) -> None:
         config_dict = default_config("test")