summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10910.feature1
-rw-r--r--docs/modules/spam_checker_callbacks.md15
-rw-r--r--synapse/events/spamcheck.py24
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_member.py31
-rw-r--r--tests/rest/client/test_rooms.py101
6 files changed, 174 insertions, 0 deletions
diff --git a/changelog.d/10910.feature b/changelog.d/10910.feature
new file mode 100644
index 0000000000..aee139f8b6
--- /dev/null
+++ b/changelog.d/10910.feature
@@ -0,0 +1 @@
+Add a spam checker callback to allow or deny room joins.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 7920ac5f8f..92376df993 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -19,6 +19,21 @@ either a `bool` to indicate whether the event must be rejected because of spam,
 to indicate the event must be rejected because of spam and to give a rejection reason to
 forward to clients.
 
+### `user_may_join_room`
+
+```python
+async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool
+```
+
+Called when a user is trying to join a room. The module must return a `bool` to indicate
+whether the user can join the room. The user is represented by their Matrix user ID (e.g.
+`@alice:example.com`) and the room is represented by its Matrix ID (e.g.
+`!room:example.com`). The module is also given a boolean to indicate whether the user
+currently has a pending invite in the room.
+
+This callback isn't called if the join is performed by a server administrator, or in the
+context of a room creation.
+
 ### `user_may_invite`
 
 ```python
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index c389f70b8d..ec8863e397 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -44,6 +44,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
     ["synapse.events.EventBase"],
     Awaitable[Union[bool, str]],
 ]
+USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[
@@ -165,6 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
 class SpamChecker:
     def __init__(self):
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+        self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
         self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
         self._user_may_create_room_with_invites_callbacks: List[
@@ -187,6 +189,7 @@ class SpamChecker:
     def register_callbacks(
         self,
         check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+        user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
         user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
         user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
         user_may_create_room_with_invites: Optional[
@@ -206,6 +209,9 @@ class SpamChecker:
         if check_event_for_spam is not None:
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
 
+        if user_may_join_room is not None:
+            self._user_may_join_room_callbacks.append(user_may_join_room)
+
         if user_may_invite is not None:
             self._user_may_invite_callbacks.append(user_may_invite)
 
@@ -259,6 +265,24 @@ class SpamChecker:
 
         return False
 
+    async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+        """Checks if a given users is allowed to join a room.
+        Not called when a user creates a room.
+
+        Args:
+            userid: The ID of the user wanting to join the room
+            room_id: The ID of the room the user wants to join
+            is_invited: Whether the user is invited into the room
+
+        Returns:
+            bool: Whether the user may join the room
+        """
+        for callback in self._user_may_join_room_callbacks:
+            if await callback(user_id, room_id, is_invited) is False:
+                return False
+
+        return True
+
     async def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
     ) -> bool:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 873e08258e..d40dbd761d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -860,6 +860,7 @@ class RoomCreationHandler(BaseHandler):
                     "invite",
                     ratelimit=False,
                     content=content,
+                    new_room=True,
                 )
 
         for invite_3pid in invite_3pid_list:
@@ -962,6 +963,7 @@ class RoomCreationHandler(BaseHandler):
             "join",
             ratelimit=ratelimit,
             content=creator_join_profile,
+            new_room=True,
         )
 
         # We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c8fb24a20c..0b79dbcf8d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -434,6 +434,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         third_party_signed: Optional[dict] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
+        new_room: bool = False,
         require_consent: bool = True,
         outlier: bool = False,
         prev_event_ids: Optional[List[str]] = None,
@@ -451,6 +452,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             third_party_signed: Information from a 3PID invite.
             ratelimit: Whether to rate limit the request.
             content: The content of the created event.
+            new_room: Whether the membership update is happening in the context of a room
+                creation.
             require_consent: Whether consent is required.
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
@@ -485,6 +488,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 third_party_signed=third_party_signed,
                 ratelimit=ratelimit,
                 content=content,
+                new_room=new_room,
                 require_consent=require_consent,
                 outlier=outlier,
                 prev_event_ids=prev_event_ids,
@@ -504,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         third_party_signed: Optional[dict] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
+        new_room: bool = False,
         require_consent: bool = True,
         outlier: bool = False,
         prev_event_ids: Optional[List[str]] = None,
@@ -523,6 +528,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             third_party_signed:
             ratelimit:
             content:
+            new_room: Whether the membership update is happening in the context of a room
+                creation.
             require_consent:
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
@@ -726,6 +733,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     # so don't really fit into the general auth process.
                     raise AuthError(403, "Guest access not allowed")
 
+            # Figure out whether the user is a server admin to determine whether they
+            # should be able to bypass the spam checker.
+            if (
+                self._server_notices_mxid is not None
+                and requester.user.to_string() == self._server_notices_mxid
+            ):
+                # allow the server notices mxid to join rooms
+                bypass_spam_checker = True
+
+            else:
+                bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+
+            inviter = await self._get_inviter(target.to_string(), room_id)
+            if (
+                not bypass_spam_checker
+                # We assume that if the spam checker allowed the user to create
+                # a room then they're allowed to join it.
+                and not new_room
+                and not await self.spam_checker.user_may_join_room(
+                    target.to_string(), room_id, is_invited=inviter is not None
+                )
+            ):
+                raise SynapseError(403, "Not allowed to join this room")
+
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
                 target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 30bdaa9c27..a41ec6a98f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -784,6 +784,30 @@ class RoomsCreateTestCase(RoomBase):
         # Check that do_3pid_invite wasn't called this time.
         self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
 
+    def test_spam_checker_may_join_room(self):
+        """Tests that the user_may_join_room spam checker callback is correctly bypassed
+        when creating a new room.
+        """
+
+        async def user_may_join_room(
+            mxid: str,
+            room_id: str,
+            is_invite: bool,
+        ) -> bool:
+            return False
+
+        join_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            {},
+        )
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        self.assertEquals(join_mock.call_count, 0)
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""
@@ -975,6 +999,83 @@ class RoomInviteRatelimitTestCase(RoomBase):
         self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
 
 
+class RoomJoinTestCase(RoomBase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user1 = self.register_user("thomas", "hackme")
+        self.tok1 = self.login("thomas", "hackme")
+
+        self.user2 = self.register_user("teresa", "hackme")
+        self.tok2 = self.login("teresa", "hackme")
+
+        self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+        self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+        self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+
+    def test_spam_checker_may_join_room(self):
+        """Tests that the user_may_join_room spam checker callback is correctly called
+        and blocks room joins when needed.
+        """
+
+        # Register a dummy callback. Make it allow all room joins for now.
+        return_value = True
+
+        async def user_may_join_room(
+            userid: str,
+            room_id: str,
+            is_invited: bool,
+        ) -> bool:
+            return return_value
+
+        callback_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+        # Join a first room, without being invited to it.
+        self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room1,
+                False,
+            ),
+        )
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Join a second room, this time with an invite for it.
+        self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+        self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room2,
+                True,
+            ),
+        )
+        self.assertEquals(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now make the callback deny all room joins, and check that a join actually fails.
+        return_value = False
+        self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"