summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10898.feature1
-rw-r--r--docs/modules/spam_checker_callbacks.md29
-rw-r--r--synapse/events/spamcheck.py42
-rw-r--r--synapse/handlers/room.py14
-rw-r--r--tests/rest/client/test_rooms.py119
5 files changed, 199 insertions, 6 deletions
diff --git a/changelog.d/10898.feature b/changelog.d/10898.feature
new file mode 100644
index 0000000000..97fa39fd0c
--- /dev/null
+++ b/changelog.d/10898.feature
@@ -0,0 +1 @@
+Add a `user_may_create_room_with_invites` spam checker callback to allow modules to allow or deny a room creation request based on the invites and/or 3PID invites it includes.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 81574a015c..7920ac5f8f 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -38,6 +38,35 @@ async def user_may_create_room(user: str) -> bool
 Called when processing a room creation request. The module must return a `bool` indicating
 whether the given user (represented by their Matrix user ID) is allowed to create a room.
 
+### `user_may_create_room_with_invites`
+
+```python
+async def user_may_create_room_with_invites(
+    user: str,
+    invites: List[str],
+    threepid_invites: List[Dict[str, str]],
+) -> bool
+```
+
+Called when processing a room creation request (right after `user_may_create_room`).
+The module is given the Matrix user ID of the user trying to create a room, as well as a
+list of Matrix users to invite and a list of third-party identifiers (3PID, e.g. email
+addresses) to invite.
+
+An invited Matrix user to invite is represented by their Matrix user IDs, and an invited
+3PIDs is represented by a dict that includes the 3PID medium (e.g. "email") through its
+`medium` key and its address (e.g. "alice@example.com") through its `address` key.
+
+See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types) for more
+information regarding third-party identifiers.
+
+If no invite and/or 3PID invite were specified in the room creation request, the
+corresponding list(s) will be empty.
+
+**Note**: This callback is not called when a room is cloned (e.g. during a room upgrade)
+since no invites are sent when cloning a room. To cover this case, modules also need to
+implement `user_may_create_room`.
+
 ### `user_may_create_room_alias`
 
 ```python
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 19ee246f96..c389f70b8d 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,6 +46,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
 ]
 USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
+USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
+    [str, List[str], List[Dict[str, str]]], Awaitable[bool]
+]
 USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
 USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
 CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
@@ -164,6 +167,9 @@ class SpamChecker:
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
         self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
+        self._user_may_create_room_with_invites_callbacks: List[
+            USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+        ] = []
         self._user_may_create_room_alias_callbacks: List[
             USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
         ] = []
@@ -183,6 +189,9 @@ class SpamChecker:
         check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
         user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
         user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+        user_may_create_room_with_invites: Optional[
+            USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+        ] = None,
         user_may_create_room_alias: Optional[
             USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
         ] = None,
@@ -203,6 +212,11 @@ class SpamChecker:
         if user_may_create_room is not None:
             self._user_may_create_room_callbacks.append(user_may_create_room)
 
+        if user_may_create_room_with_invites is not None:
+            self._user_may_create_room_with_invites_callbacks.append(
+                user_may_create_room_with_invites,
+            )
+
         if user_may_create_room_alias is not None:
             self._user_may_create_room_alias_callbacks.append(
                 user_may_create_room_alias,
@@ -283,6 +297,34 @@ class SpamChecker:
 
         return True
 
+    async def user_may_create_room_with_invites(
+        self,
+        userid: str,
+        invites: List[str],
+        threepid_invites: List[Dict[str, str]],
+    ) -> bool:
+        """Checks if a given user may create a room with invites
+
+        If this method returns false, the creation request will be rejected.
+
+        Args:
+            userid: The ID of the user attempting to create a room
+            invites: The IDs of the Matrix users to be invited if the room creation is
+                allowed.
+            threepid_invites: The threepids to be invited if the room creation is allowed,
+                as a dict including a "medium" key indicating the threepid's medium (e.g.
+                "email") and an "address" key indicating the threepid's address (e.g.
+                "alice@example.com")
+
+        Returns:
+            True if the user may create the room, otherwise False
+        """
+        for callback in self._user_may_create_room_with_invites_callbacks:
+            if await callback(userid, invites, threepid_invites) is False:
+                return False
+
+        return True
+
     async def user_may_create_room_alias(
         self, userid: str, room_alias: RoomAlias
     ) -> bool:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 408b7d7b74..8fede5e935 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -649,8 +649,16 @@ class RoomCreationHandler(BaseHandler):
             requester, config, is_requester_admin=is_requester_admin
         )
 
-        if not is_requester_admin and not await self.spam_checker.user_may_create_room(
-            user_id
+        invite_3pid_list = config.get("invite_3pid", [])
+        invite_list = config.get("invite", [])
+
+        if not is_requester_admin and not (
+            await self.spam_checker.user_may_create_room(user_id)
+            and await self.spam_checker.user_may_create_room_with_invites(
+                user_id,
+                invite_list,
+                invite_3pid_list,
+            )
         ):
             raise SynapseError(403, "You are not permitted to create rooms")
 
@@ -684,8 +692,6 @@ class RoomCreationHandler(BaseHandler):
             if mapping:
                 raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
 
-        invite_3pid_list = config.get("invite_3pid", [])
-        invite_list = config.get("invite", [])
         for i in invite_list:
             try:
                 uid = UserID.from_string(i)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index ef847f0f5f..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
@@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request("POST", "/createRoom", content)
         self.assertEqual(200, channel.code)
 
+    def test_spamchecker_invites(self):
+        """Tests the user_may_create_room_with_invites spam checker callback."""
+
+        # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+        # IS.
+        async def do_3pid_invite(
+            room_id: str,
+            inviter: UserID,
+            medium: str,
+            address: str,
+            id_server: str,
+            requester: Requester,
+            txn_id: Optional[str],
+            id_access_token: Optional[str] = None,
+        ) -> int:
+            return 0
+
+        do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+        self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+        # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+        # room creation request for now.
+        return_value = True
+
+        async def user_may_create_room_with_invites(
+            user: str,
+            invites: List[str],
+            threepid_invites: List[Dict[str, str]],
+        ) -> bool:
+            return return_value
+
+        callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+        self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+            callback_mock,
+        )
+
+        # The MXIDs we'll try to invite.
+        invited_mxids = [
+            "@alice1:red",
+            "@alice2:red",
+            "@alice3:red",
+            "@alice4:red",
+        ]
+
+        # The 3PIDs we'll try to invite.
+        invited_3pids = [
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice1@example.com",
+            },
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice2@example.com",
+            },
+            {
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": "alice3@example.com",
+            },
+        ]
+
+        # Create a room and invite the Matrix users, and check that it succeeded.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite": invited_mxids}).encode("utf8"),
+        )
+        self.assertEqual(200, channel.code)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = ((self.user_id, invited_mxids, []),)
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Create a room and invite the 3PIDs, and check that it succeeded.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+        )
+        self.assertEqual(200, channel.code)
+
+        # Check that do_3pid_invite was called the right amount of time
+        self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = ((self.user_id, [], invited_3pids),)
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now deny any room creation.
+        return_value = False
+
+        # Create a room and invite the 3PIDs, and check that it failed.
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+        )
+        self.assertEqual(403, channel.code)
+
+        # Check that do_3pid_invite wasn't called this time.
+        self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""