summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_rooms.py101
1 files changed, 101 insertions, 0 deletions
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..a41ec6a98f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
         # Check that do_3pid_invite wasn't called this time.
         self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
 
+    def test_spam_checker_may_join_room(self):
+        """Tests that the user_may_join_room spam checker callback is correctly bypassed
+        when creating a new room.
+        """
+
+        async def user_may_join_room(
+            mxid: str,
+            room_id: str,
+            is_invite: bool,
+        ) -> bool:
+            return False
+
+        join_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            {},
+        )
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        self.assertEquals(join_mock.call_count, 0)
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
         self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
 
 
+class RoomJoinTestCase(RoomBase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user1 = self.register_user("thomas", "hackme")
+        self.tok1 = self.login("thomas", "hackme")
+
+        self.user2 = self.register_user("teresa", "hackme")
+        self.tok2 = self.login("teresa", "hackme")
+
+        self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+        self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+        self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+    def test_spam_checker_may_join_room(self):
+        """Tests that the user_may_join_room spam checker callback is correctly called
+        and blocks room joins when needed.
+        """
+
+        # Register a dummy callback. Make it allow all room joins for now.
+        return_value = True
+
+        async def user_may_join_room(
+            userid: str,
+            room_id: str,
+            is_invited: bool,
+        ) -> bool:
+            return return_value
+
+        callback_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+        # Join a first room, without being invited to it.
+        self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room1,
+                False,
+            ),
+        )
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Join a second room, this time with an invite for it.
+        self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+        self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room2,
+                True,
+            ),
+        )
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now make the callback deny all room joins, and check that a join actually fails.
+        return_value = False
+        self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"