summary refs log tree commit diff
path: root/tests/module_api/test_spamchecker.py
diff options
context:
space:
mode:
authorHugh Nimmo-Smith <hughns@users.noreply.github.com>2025-06-04 11:30:45 +0100
committerGitHub <noreply@github.com>2025-06-04 11:30:45 +0100
commitfbe7a898f0380aa194c26b71eccf029d0ac47b5c (patch)
tree479a0222b3192614d2a863d102c533e5c4ef5c23 /tests/module_api/test_spamchecker.py
parentMerge branch 'master' into develop (diff)
downloadsynapse-fbe7a898f0380aa194c26b71eccf029d0ac47b5c.tar.xz
Pass room_config argument to user_may_create_room spam checker module callback (#18486)
This PR adds an additional `room_config` argument to the
`user_may_create_room` spam checker module API callback.

It will continue to work with implementations of `user_may_create_room`
that do not expect the additional parameter.

A side affect is that on a room upgrade the spam checker callback is
called *after* doing some work to calculate the state rather than
before. However, I hope that this is acceptable given the relative
infrequency of room upgrades.
Diffstat (limited to 'tests/module_api/test_spamchecker.py')
-rw-r--r--tests/module_api/test_spamchecker.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py
new file mode 100644

index 0000000000..82790222c8 --- /dev/null +++ b/tests/module_api/test_spamchecker.py
@@ -0,0 +1,155 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# +from typing import Literal, Union + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin, login, room, room_upgrade_rest_servlet +from synapse.server import HomeServer +from synapse.types import Codes, JsonDict +from synapse.util import Clock + +from tests.server import FakeChannel +from tests.unittest import HomeserverTestCase + + +class SpamCheckerTestCase(HomeserverTestCase): + servlets = [ + room.register_servlets, + admin.register_servlets, + login.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self._module_api = homeserver.get_module_api() + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def create_room(self, content: JsonDict) -> FakeChannel: + channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + content, + access_token=self.token, + ) + + return channel + + def test_may_user_create_room(self) -> None: + """Test that the may_user_create_room callback is called when a user + creates a room, and that it receives the correct parameters. + """ + + async def user_may_create_room( + user_id: str, room_config: JsonDict + ) -> Union[Literal["NOT_SPAM"], Codes]: + self.last_room_config = room_config + self.last_user_id = user_id + return "NOT_SPAM" + + self._module_api.register_spam_checker_callbacks( + user_may_create_room=user_may_create_room + ) + + channel = self.create_room({"foo": "baa"}) + self.assertEqual(channel.code, 200) + self.assertEqual(self.last_user_id, self.user_id) + self.assertEqual(self.last_room_config["foo"], "baa") + + def test_may_user_create_room_on_upgrade(self) -> None: + """Test that the may_user_create_room callback is called when a room is upgraded.""" + + # First, create a room to upgrade. + channel = self.create_room({"topic": "foo"}) + self.assertEqual(channel.code, 200) + room_id = channel.json_body["room_id"] + + async def user_may_create_room( + user_id: str, room_config: JsonDict + ) -> Union[Literal["NOT_SPAM"], Codes]: + self.last_room_config = room_config + self.last_user_id = user_id + return "NOT_SPAM" + + # Register the callback for spam checking. + self._module_api.register_spam_checker_callbacks( + user_may_create_room=user_may_create_room + ) + + # Now upgrade the room. + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{room_id}/upgrade", + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=self.token, + ) + + # Check that the callback was called and the room was upgraded. + self.assertEqual(channel.code, 200) + self.assertEqual(self.last_user_id, self.user_id) + # Check that the initial state received by callback contains the topic event. + self.assertTrue( + any( + event[0][0] == "m.room.topic" and event[1].get("topic") == "foo" + for event in self.last_room_config["initial_state"] + ) + ) + + def test_may_user_create_room_disallowed(self) -> None: + """Test that the codes response from may_user_create_room callback is respected + and returned via the API. + """ + + async def user_may_create_room( + user_id: str, room_config: JsonDict + ) -> Union[Literal["NOT_SPAM"], Codes]: + self.last_room_config = room_config + self.last_user_id = user_id + return Codes.UNAUTHORIZED + + self._module_api.register_spam_checker_callbacks( + user_may_create_room=user_may_create_room + ) + + channel = self.create_room({"foo": "baa"}) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEqual(self.last_user_id, self.user_id) + self.assertEqual(self.last_room_config["foo"], "baa") + + def test_may_user_create_room_compatibility(self) -> None: + """Test that the may_user_create_room callback is called when a user + creates a room for a module that uses the old callback signature + (without the `room_config` parameter) + """ + + async def user_may_create_room( + user_id: str, + ) -> Union[Literal["NOT_SPAM"], Codes]: + self.last_user_id = user_id + return "NOT_SPAM" + + self._module_api.register_spam_checker_callbacks( + user_may_create_room=user_may_create_room + ) + + channel = self.create_room({"foo": "baa"}) + self.assertEqual(channel.code, 200) + self.assertEqual(self.last_user_id, self.user_id)