summary refs log tree commit diff
path: root/tests/rest/client/test_rooms.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_rooms.py')
-rw-r--r--tests/rest/client/test_rooms.py119
1 files changed, 117 insertions, 2 deletions
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index ef847f0f5f..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
@@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request("POST", "/createRoom", content)
         self.assertEqual(200, channel.code)
 
+    def test_spamchecker_invites(self):
+        """Tests the user_may_create_room_with_invites spam checker callback."""
+
+        # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+        # IS.
+        async def do_3pid_invite(
+            room_id: str,
+            inviter: UserID,
+            medium: str,
+            address: str,
+            id_server: str,
+            requester: Requester,
+            txn_id: Optional[str],
+            id_access_token: Optional[str] = None,
+        ) -> int:
+            return 0
+
+        do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+        self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+        # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+        # room creation request for now.
+        return_value = True
+
+        async def user_may_create_room_with_invites(
+            user: str,
+            invites: List[str],
+            threepid_invites: List[Dict[str, str]],
+        ) -> bool:
+            return return_value
+
+        callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+        self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+            callback_mock,
+        )
+
+        # The MXIDs we'll try to invite.
+        invited_mxids = [
+            "@alice1:red",
+            "@alice2:red",
+            "@alice3:red",
+            "@alice4:red",
+        ]
+
+        # The 3PIDs we'll try to invite.
+        invited_3pids = [
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice1@example.com",
+            },
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice2@example.com",
+            },
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice3@example.com",
+            },
+        ]
+
+        # Create a room and invite the Matrix users, and check that it succeeded.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite": invited_mxids}).encode("utf8"),
+        )
+        self.assertEqual(200, channel.code)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = ((self.user_id, invited_mxids, []),)
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Create a room and invite the 3PIDs, and check that it succeeded.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+        )
+        self.assertEqual(200, channel.code)
+
+        # Check that do_3pid_invite was called the right amount of time
+        self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = ((self.user_id, [], invited_3pids),)
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now deny any room creation.
+        return_value = False
+
+        # Create a room and invite the 3PIDs, and check that it failed.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+        )
+        self.assertEqual(403, channel.code)
+
+        # Check that do_3pid_invite wasn't called this time.
+        self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""