summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_rooms.py175
1 files changed, 170 insertions, 5 deletions
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4be83dfd6d..35c59ee9e0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,13 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
+# `Literal` appears with Python 3.8.
+from typing_extensions import Literal
+
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -777,9 +780,11 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request("POST", "/createRoom", content)
         self.assertEqual(200, channel.code)
 
-    def test_spam_checker_may_join_room(self) -> None:
+    def test_spam_checker_may_join_room_deprecated(self) -> None:
         """Tests that the user_may_join_room spam checker callback is correctly bypassed
         when creating a new room.
+
+        In this test, we use the deprecated API in which callbacks return a bool.
         """
 
         async def user_may_join_room(
@@ -801,6 +806,32 @@ class RoomsCreateTestCase(RoomBase):
 
         self.assertEqual(join_mock.call_count, 0)
 
+    def test_spam_checker_may_join_room(self) -> None:
+        """Tests that the user_may_join_room spam checker callback is correctly bypassed
+        when creating a new room.
+
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+        """
+
+        async def user_may_join_room(
+            mxid: str,
+            room_id: str,
+            is_invite: bool,
+        ) -> Codes:
+            return Codes.CONSENT_NOT_GIVEN
+
+        join_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            {},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        self.assertEqual(join_mock.call_count, 0)
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""
@@ -1011,9 +1042,11 @@ class RoomJoinTestCase(RoomBase):
         self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
         self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
 
-    def test_spam_checker_may_join_room(self) -> None:
+    def test_spam_checker_may_join_room_deprecated(self) -> None:
         """Tests that the user_may_join_room spam checker callback is correctly called
         and blocks room joins when needed.
+
+        This test uses the deprecated API, in which callbacks return booleans.
         """
 
         # Register a dummy callback. Make it allow all room joins for now.
@@ -1026,6 +1059,8 @@ class RoomJoinTestCase(RoomBase):
         ) -> bool:
             return return_value
 
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
         callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
         self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
 
@@ -1068,6 +1103,67 @@ class RoomJoinTestCase(RoomBase):
         return_value = False
         self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
 
+    def test_spam_checker_may_join_room(self) -> None:
+        """Tests that the user_may_join_room spam checker callback is correctly called
+        and blocks room joins when needed.
+
+        This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
+        """
+
+        # Register a dummy callback. Make it allow all room joins for now.
+        return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM
+
+        async def user_may_join_room(
+            userid: str,
+            room_id: str,
+            is_invited: bool,
+        ) -> Union[Literal["NOT_SPAM"], Codes]:
+            return return_value
+
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
+        callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+        # Join a first room, without being invited to it.
+        self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room1,
+                False,
+            ),
+        )
+        self.assertEqual(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Join a second room, this time with an invite for it.
+        self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+        self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room2,
+                True,
+            ),
+        )
+        self.assertEqual(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now make the callback deny all room joins, and check that a join actually fails.
+        return_value = Codes.CONSENT_NOT_GIVEN
+        self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
 
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
@@ -2945,9 +3041,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
 
-    def test_threepid_invite_spamcheck(self) -> None:
+    def test_threepid_invite_spamcheck_deprecated(self) -> None:
+        """
+        Test allowing/blocking threepid invites with a spam-check module.
+
+        In this test, we use the deprecated API in which callbacks return a bool.
+        """
         # Mock a few functions to prevent the test from failing due to failing to talk to
-        # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+        # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
         make_invite_mock = Mock(return_value=make_awaitable(0))
         self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
@@ -3001,3 +3102,67 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         # Also check that it stopped before calling _make_and_store_3pid_invite.
         make_invite_mock.assert_called_once()
+
+    def test_threepid_invite_spamcheck(self) -> None:
+        """
+        Test allowing/blocking threepid invites with a spam-check module.
+
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+        # Mock a few functions to prevent the test from failing due to failing to talk to
+        # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
+        # can check its call_count later on during the test.
+        make_invite_mock = Mock(return_value=make_awaitable(0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+        self.hs.get_identity_handler().lookup_3pid = Mock(
+            return_value=make_awaitable(None),
+        )
+
+        # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+        # allow everything for now.
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
+        mock = Mock(
+            return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+            spec=lambda *x: None,
+        )
+        self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+        # Send a 3PID invite into the room and check that it succeeded.
+        email_to_invite = "teresa@example.com"
+        channel = self.make_request(
+            method="POST",
+            path="/rooms/" + self.room_id + "/invite",
+            content={
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": email_to_invite,
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Check that the callback was called with the right params.
+        mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+        # Check that the call to send the invite was made.
+        make_invite_mock.assert_called_once()
+
+        # Now change the return value of the callback to deny any invite and test that
+        # we can't send the invite.
+        mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+        channel = self.make_request(
+            method="POST",
+            path="/rooms/" + self.room_id + "/invite",
+            content={
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": email_to_invite,
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 403)
+
+        # Also check that it stopped before calling _make_and_store_3pid_invite.
+        make_invite_mock.assert_called_once()