diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1a1cbde74e..5e73f5d5ec 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,6 +1,10 @@
+from typing import Optional
+
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.config.ratelimiting import RatelimitSettings
+from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks
+from synapse.storage.databases.main.room import RatelimitOverride
from synapse.types import create_requester
from tests import unittest
@@ -440,3 +444,49 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)
+
+ def test_get_ratelimit_override_for_user_callback(self) -> None:
+ test_user_id = "@user:test"
+ test_limiter_name = "name"
+ callbacks = RatelimitModuleApiCallbacks(self.hs)
+ requester = create_requester(test_user_id)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ cfg=RatelimitSettings(
+ test_limiter_name,
+ per_second=0.1,
+ burst_count=3,
+ ),
+ ratelimit_callbacks=callbacks,
+ )
+
+ # Observe four actions, exceeding the burst_count.
+ limiter.record_action(requester=requester, n_actions=4, _time_now_s=0.0)
+
+ # We should be prevented from taking a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=requester, _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # Now register a callback that overrides the ratelimit for this user
+ # and limiter name.
+ async def get_ratelimit_override_for_user(
+ user_id: str, limiter_name: str
+ ) -> Optional[RatelimitOverride]:
+ if user_id == test_user_id:
+ return RatelimitOverride(
+ messages_per_second=0.1,
+ burst_count=10,
+ )
+ return None
+
+ callbacks.register_callbacks(
+ get_ratelimit_override_for_user=get_ratelimit_override_for_user
+ )
+
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=requester, _time_now_s=0.0)
+ )
+ self.assertTrue(success)
|