diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..aa84906548 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
+ @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..35c59ee9e0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,13 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
+# `Literal` appears with Python 3.8.
+from typing_extensions import Literal
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -42,6 +45,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -471,6 +475,49 @@ class RoomPermissionsTestCase(RoomBase):
)
+class RoomStateTestCase(RoomBase):
+ """Tests /rooms/$room_id/state."""
+
+ user_id = "@sid1:red"
+
+ def test_get_state_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state" % room_id,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertCountEqual(
+ [state_event["type"] for state_event in channel.json_body],
+ {
+ "m.room.create",
+ "m.room.power_levels",
+ "m.room.join_rules",
+ "m.room.member",
+ "m.room.history_visibility",
+ },
+ )
+
+ def test_get_state_event_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(channel.json_body, {"membership": "join"})
+
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -591,6 +638,62 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", room_path)
self.assertEqual(200, channel.code, msg=channel.result["body"])
+ def test_get_member_list_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members" % room_id,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
+ def test_get_member_list_with_at_token_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEqual(200, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_with_at_token_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members?at=%s" % (room_id, sync_token),
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
class RoomsCreateTestCase(RoomBase):
"""Tests /rooms and /rooms/$room_id REST events."""
@@ -677,9 +780,11 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
"""
async def user_may_join_room(
@@ -701,6 +806,32 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(join_mock.call_count, 0)
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
+
+ async def user_may_join_room(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Codes:
+ return Codes.CONSENT_NOT_GIVEN
+
+ join_mock = Mock(side_effect=user_may_join_room)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ self.assertEqual(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -911,9 +1042,11 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
+
+ This test uses the deprecated API, in which callbacks return booleans.
"""
# Register a dummy callback. Make it allow all room joins for now.
@@ -926,6 +1059,8 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
@@ -968,6 +1103,67 @@ class RoomJoinTestCase(RoomBase):
return_value = False
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+
+ This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ return return_value
+
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ return_value = Codes.CONSENT_NOT_GIVEN
+ self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -2845,9 +3041,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_threepid_invite_spamcheck(self) -> None:
+ def test_threepid_invite_spamcheck_deprecated(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable(0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
@@ -2901,3 +3102,67 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_threepid_invite_spamcheck(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable(0))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(
+ return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+ spec=lambda *x: None,
+ )
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite.
+ mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
|