diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index ef847f0f5f..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
+ def test_spamchecker_invites(self):
+ """Tests the user_may_create_room_with_invites spam checker callback."""
+
+ # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+ # IS.
+ async def do_3pid_invite(
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
+ return 0
+
+ do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+ self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+ # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+ # room creation request for now.
+ return_value = True
+
+ async def user_may_create_room_with_invites(
+ user: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+ self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+ callback_mock,
+ )
+
+ # The MXIDs we'll try to invite.
+ invited_mxids = [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+
+ # The 3PIDs we'll try to invite.
+ invited_3pids = [
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice1@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice2@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice3@example.com",
+ },
+ ]
+
+ # Create a room and invite the Matrix users, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite": invited_mxids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, invited_mxids, []),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Create a room and invite the 3PIDs, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that do_3pid_invite was called the right amount of time
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, [], invited_3pids),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now deny any room creation.
+ return_value = False
+
+ # Create a room and invite the 3PIDs, and check that it failed.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(403, channel.code)
+
+ # Check that do_3pid_invite wasn't called this time.
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
|