summary refs log tree commit diff
path: root/tests/rest/client/test_delayed_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_delayed_events.py')
-rw-r--r--tests/rest/client/test_delayed_events.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py

index 1793b38c4a..2c938390c8 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py
@@ -109,6 +109,27 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + @unittest.override_config( + {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} + ) + def test_get_delayed_events_ratelimit(self) -> None: + args = ("GET", PATH_PREFIX) + + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the request isn't ratelimited anymore. + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + def test_update_delayed_event_without_id(self) -> None: channel = self.make_request( "POST", @@ -206,6 +227,46 @@ class DelayedEventsTestCase(HomeserverTestCase): expect_code=HTTPStatus.NOT_FOUND, ) + @unittest.override_config( + {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} + ) + def test_cancel_delayed_event_ratelimit(self) -> None: + delay_ids = [] + for _ in range(2): + channel = self.make_request( + "POST", + _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), + {}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + delay_ids.append(delay_id) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "cancel"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + args = ( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "cancel"}, + ) + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the request isn't ratelimited anymore. + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + def test_send_delayed_state_event(self) -> None: state_key = "to_send_on_request" @@ -250,6 +311,44 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}}) + def test_send_delayed_event_ratelimit(self) -> None: + delay_ids = [] + for _ in range(2): + channel = self.make_request( + "POST", + _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), + {}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + delay_ids.append(delay_id) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "send"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + args = ( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "send"}, + ) + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the request isn't ratelimited anymore. + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + def test_restart_delayed_state_event(self) -> None: state_key = "to_send_on_restarted_timeout" @@ -309,6 +408,46 @@ class DelayedEventsTestCase(HomeserverTestCase): ) self.assertEqual(setter_expected, content.get(setter_key), content) + @unittest.override_config( + {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} + ) + def test_restart_delayed_event_ratelimit(self) -> None: + delay_ids = [] + for _ in range(2): + channel = self.make_request( + "POST", + _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), + {}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = channel.json_body.get("delay_id") + self.assertIsNotNone(delay_id) + delay_ids.append(delay_id) + + channel = self.make_request( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "restart"}, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + args = ( + "POST", + f"{PATH_PREFIX}/{delay_ids.pop(0)}", + {"action": "restart"}, + ) + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the request isn't ratelimited anymore. + channel = self.make_request(*args) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: state_key = "to_be_cancelled" @@ -374,3 +513,7 @@ def _get_path_for_delayed_state( room_id: str, event_type: str, state_key: str, delay_ms: int ) -> str: return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}" + + +def _get_path_for_delayed_send(room_id: str, event_type: str, delay_ms: int) -> str: + return f"rooms/{room_id}/send/{event_type}?org.matrix.msc4140.delay={delay_ms}"