diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
index 319abfe63d..8e159029d9 100644
--- a/tests/api/test_errors.py
+++ b/tests/api/test_errors.py
@@ -1,6 +1,5 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
-#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -13,24 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+
from synapse.api.errors import LimitExceededError
from tests import unittest
-class ErrorsTestCase(unittest.TestCase):
+class LimitExceededErrorTestCase(unittest.TestCase):
+ def test_key_appears_in_context_but_not_error_dict(self) -> None:
+ err = LimitExceededError("needle")
+ serialised = json.dumps(err.error_dict(None))
+ self.assertIn("needle", err.debug_context)
+ self.assertNotIn("needle", serialised)
+
# Create a sub-class to avoid mutating the class-level property.
class LimitExceededErrorHeaders(LimitExceededError):
include_retry_after_header = True
def test_limit_exceeded_header(self) -> None:
- err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100)
+ err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "1")
def test_limit_exceeded_rounding(self) -> None:
- err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001)
+ err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "4")
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fa6c1c02ce..a24638c9ef 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,5 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
+from synapse.config.ratelimiting import RatelimitSettings
from synapse.types import create_requester
from tests import unittest
@@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(
+ key="",
+ per_second=0.1,
+ burst_count=1,
+ ),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(
+ key="",
+ per_second=0.1,
+ burst_count=1,
+ ),
)
allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0)
@@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# Shouldn't raise
@@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# First attempt should be allowed
@@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
# First attempt should be allowed
@@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=1,
+ cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
)
self.get_success_or_raise(
limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
)
)
- limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=store,
+ clock=self.clock,
+ cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
+ )
# Shouldn't raise
for _ in range(20):
@@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=3,
+ cfg=RatelimitSettings(
+ key="",
+ per_second=0.1,
+ burst_count=3,
+ ),
)
# Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise(
@@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=3,
+ cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
)
def consume_at(time: float) -> bool:
@@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=3,
+ cfg=RatelimitSettings(
+ "",
+ per_second=0.1,
+ burst_count=3,
+ ),
)
# Observe two actions, leaving room in the bucket for one more.
@@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=3,
+ cfg=RatelimitSettings(
+ "",
+ per_second=0.1,
+ burst_count=3,
+ ),
)
# Observe three actions, filling up the bucket.
@@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
- rate_hz=0.1,
- burst_count=3,
+ cfg=RatelimitSettings(
+ "",
+ per_second=0.1,
+ burst_count=3,
+ ),
)
# Observe four actions, exceeding the bucket.
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index f12147eaa0..0c27dd21e2 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -12,11 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import RatelimitSettings
from tests.unittest import TestCase
from tests.utils import default_config
+class ParseRatelimitSettingsTestcase(TestCase):
+ def test_depth_1(self) -> None:
+ cfg = {
+ "a": {
+ "per_second": 5,
+ "burst_count": 10,
+ }
+ }
+ parsed = RatelimitSettings.parse(cfg, "a")
+ self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+ def test_depth_2(self) -> None:
+ cfg = {
+ "a": {
+ "b": {
+ "per_second": 5,
+ "burst_count": 10,
+ },
+ }
+ }
+ parsed = RatelimitSettings.parse(cfg, "a.b")
+ self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
+
+ def test_missing(self) -> None:
+ parsed = RatelimitSettings.parse(
+ {}, "a", defaults={"per_second": 5, "burst_count": 10}
+ )
+ self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+
class RatelimitConfigTestCase(TestCase):
def test_parse_rc_federation(self) -> None:
config_dict = default_config("test")
|