summary refs log tree commit diff
path: root/tests/api/test_ratelimiting.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api/test_ratelimiting.py')
-rw-r--r--tests/api/test_ratelimiting.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py

index 1a1cbde74e..5e73f5d5ec 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -1,6 +1,10 @@ +from typing import Optional + from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService from synapse.config.ratelimiting import RatelimitSettings +from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks +from synapse.storage.databases.main.room import RatelimitOverride from synapse.types import create_requester from tests import unittest @@ -440,3 +444,49 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) ) self.assertTrue(success) + + def test_get_ratelimit_override_for_user_callback(self) -> None: + test_user_id = "@user:test" + test_limiter_name = "name" + callbacks = RatelimitModuleApiCallbacks(self.hs) + requester = create_requester(test_user_id) + limiter = Ratelimiter( + store=self.hs.get_datastores().main, + clock=self.clock, + cfg=RatelimitSettings( + test_limiter_name, + per_second=0.1, + burst_count=3, + ), + ratelimit_callbacks=callbacks, + ) + + # Observe four actions, exceeding the burst_count. + limiter.record_action(requester=requester, n_actions=4, _time_now_s=0.0) + + # We should be prevented from taking a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=requester, _time_now_s=0.0) + ) + self.assertFalse(success) + + # Now register a callback that overrides the ratelimit for this user + # and limiter name. + async def get_ratelimit_override_for_user( + user_id: str, limiter_name: str + ) -> Optional[RatelimitOverride]: + if user_id == test_user_id: + return RatelimitOverride( + messages_per_second=0.1, + burst_count=10, + ) + return None + + callbacks.register_callbacks( + get_ratelimit_override_for_user=get_ratelimit_override_for_user + ) + + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=requester, _time_now_s=0.0) + ) + self.assertTrue(success)