diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..a41ec6a98f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
# Check that do_3pid_invite wasn't called this time.
self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+ """
+
+ async def user_may_join_room(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> bool:
+ return False
+
+ join_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ self.assertEquals(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+class RoomJoinTestCase(RoomBase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value = True
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ return_value = False
+ self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
|