From 28f21b403683619047668d5974219fdff8a33dfd Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 4 Jun 2025 12:26:04 +0100 Subject: Add user_may_send_state_event callback to spam checker module API (#18455) --- tests/module_api/test_spamchecker.py | 89 ++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) (limited to 'tests/module_api/test_spamchecker.py') diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py index 82790222c8..926fe30b43 100644 --- a/tests/module_api/test_spamchecker.py +++ b/tests/module_api/test_spamchecker.py @@ -153,3 +153,92 @@ class SpamCheckerTestCase(HomeserverTestCase): channel = self.create_room({"foo": "baa"}) self.assertEqual(channel.code, 200) self.assertEqual(self.last_user_id, self.user_id) + + def test_user_may_send_state_event(self) -> None: + """Test that the user_may_send_state_event callback is called when a state event + is sent, and that it receives the correct parameters. + """ + + async def user_may_send_state_event( + user_id: str, + room_id: str, + event_type: str, + state_key: str, + content: JsonDict, + ) -> Union[Literal["NOT_SPAM"], Codes]: + self.last_user_id = user_id + self.last_room_id = room_id + self.last_event_type = event_type + self.last_state_key = state_key + self.last_content = content + return "NOT_SPAM" + + self._module_api.register_spam_checker_callbacks( + user_may_send_state_event=user_may_send_state_event + ) + + channel = self.create_room({}) + self.assertEqual(channel.code, 200) + + room_id = channel.json_body["room_id"] + + event_type = "test.event.type" + state_key = "test.state.key" + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s/%s" + % ( + room_id, + event_type, + state_key, + ), + content={"foo": "bar"}, + access_token=self.token, + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.last_user_id, self.user_id) + self.assertEqual(self.last_room_id, room_id) + self.assertEqual(self.last_event_type, event_type) + self.assertEqual(self.last_state_key, state_key) + self.assertEqual(self.last_content, {"foo": "bar"}) + + def test_user_may_send_state_event_disallows(self) -> None: + """Test that the user_may_send_state_event callback is called when a state event + is sent, and that the response is honoured. + """ + + async def user_may_send_state_event( + user_id: str, + room_id: str, + event_type: str, + state_key: str, + content: JsonDict, + ) -> Union[Literal["NOT_SPAM"], Codes]: + return Codes.FORBIDDEN + + self._module_api.register_spam_checker_callbacks( + user_may_send_state_event=user_may_send_state_event + ) + + channel = self.create_room({}) + self.assertEqual(channel.code, 200) + + room_id = channel.json_body["room_id"] + + event_type = "test.event.type" + state_key = "test.state.key" + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s/%s" + % ( + room_id, + event_type, + state_key, + ), + content={"foo": "bar"}, + access_token=self.token, + ) + + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) -- cgit 1.5.1