diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py
index 82790222c8..926fe30b43 100644
--- a/tests/module_api/test_spamchecker.py
+++ b/tests/module_api/test_spamchecker.py
@@ -153,3 +153,92 @@ class SpamCheckerTestCase(HomeserverTestCase):
channel = self.create_room({"foo": "baa"})
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_user_id, self.user_id)
+
+ def test_user_may_send_state_event(self) -> None:
+ """Test that the user_may_send_state_event callback is called when a state event
+ is sent, and that it receives the correct parameters.
+ """
+
+ async def user_may_send_state_event(
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ content: JsonDict,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_user_id = user_id
+ self.last_room_id = room_id
+ self.last_event_type = event_type
+ self.last_state_key = state_key
+ self.last_content = content
+ return "NOT_SPAM"
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_send_state_event=user_may_send_state_event
+ )
+
+ channel = self.create_room({})
+ self.assertEqual(channel.code, 200)
+
+ room_id = channel.json_body["room_id"]
+
+ event_type = "test.event.type"
+ state_key = "test.state.key"
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s/%s"
+ % (
+ room_id,
+ event_type,
+ state_key,
+ ),
+ content={"foo": "bar"},
+ access_token=self.token,
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.last_user_id, self.user_id)
+ self.assertEqual(self.last_room_id, room_id)
+ self.assertEqual(self.last_event_type, event_type)
+ self.assertEqual(self.last_state_key, state_key)
+ self.assertEqual(self.last_content, {"foo": "bar"})
+
+ def test_user_may_send_state_event_disallows(self) -> None:
+ """Test that the user_may_send_state_event callback is called when a state event
+ is sent, and that the response is honoured.
+ """
+
+ async def user_may_send_state_event(
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ content: JsonDict,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ return Codes.FORBIDDEN
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_send_state_event=user_may_send_state_event
+ )
+
+ channel = self.create_room({})
+ self.assertEqual(channel.code, 200)
+
+ room_id = channel.json_body["room_id"]
+
+ event_type = "test.event.type"
+ state_key = "test.state.key"
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s/%s"
+ % (
+ room_id,
+ event_type,
+ state_key,
+ ),
+ content={"foo": "bar"},
+ access_token=self.token,
+ )
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|