diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..376853fd65 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
# Check that do_3pid_invite wasn't called this time.
self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+ """
+
+ async def user_may_join_room(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> bool:
+ return False
+
+ join_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ self.assertEquals(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+class RoomJoinTestCase(RoomBase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+ def test_spam_checker_may_join_room(self):
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value = True
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ return_value = False
+ self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -2430,3 +2531,73 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+class ThreepidInviteTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("thomas", "hackme")
+ self.tok = self.login("thomas", "hackme")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_threepid_invite_spamcheck(self):
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable(0))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ mock = Mock(return_value=make_awaitable(True))
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite.
+ mock.return_value = make_awaitable(False)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEquals(channel.code, 403)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
|