diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..9e9e953cf4 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"GET",
path,
shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"POST",
path,
content=b"",
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 65c58ce70a..84d092ca82 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
channel = make_request(
- self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ self.reactor,
+ FakeSite(resource, self.reactor),
+ "GET",
+ "/consent?v=1",
+ shorthand=False,
)
self.assertEqual(channel.code, 200)
@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
)
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f5c195a075..371615a015 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -97,7 +97,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
+ self.hs.config.captcha.enable_registration_captcha = False
return self.hs
@@ -815,9 +815,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = self.jwt_algorithm
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_secret
+ self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
@@ -1023,9 +1023,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt_algorithm = "RS256"
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt.jwt_algorithm = "RS256"
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 9f3ab2c985..72a5a11b46 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -146,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
- self.hs.config.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index ef847f0f5f..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
+ def test_spamchecker_invites(self):
+ """Tests the user_may_create_room_with_invites spam checker callback."""
+
+ # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+ # IS.
+ async def do_3pid_invite(
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
+ return 0
+
+ do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+ self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+ # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+ # room creation request for now.
+ return_value = True
+
+ async def user_may_create_room_with_invites(
+ user: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+ self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+ callback_mock,
+ )
+
+ # The MXIDs we'll try to invite.
+ invited_mxids = [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+
+ # The 3PIDs we'll try to invite.
+ invited_3pids = [
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice1@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice2@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice3@example.com",
+ },
+ ]
+
+ # Create a room and invite the Matrix users, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite": invited_mxids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, invited_mxids, []),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Create a room and invite the 3PIDs, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that do_3pid_invite was called the right amount of time
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, [], invited_3pids),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now deny any room creation.
+ return_value = False
+
+ # Create a room and invite the 3PIDs, and check that it failed.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(403, channel.code)
+
+ # Check that do_3pid_invite wasn't called this time.
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c56e45fc10..3075d3f288 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -383,7 +383,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
self.hs.get_reactor(),
- FakeSite(resource),
+ FakeSite(resource, self.hs.get_reactor()),
"POST",
path,
content=image_data,
|