diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..c7eb88d33f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Any, Dict, Iterable, List, Optional
+from http import HTTPStatus
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
+from parameterized import param, parameterized
+from typing_extensions import Literal
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,7 +34,9 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
+ PublicRoomsFilterFields,
RelationTypes,
+ RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
@@ -42,6 +48,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# set topic for public room
channel = self.make_request(
@@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase):
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
@@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
@@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def _test_get_membership(
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
@@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
def test_joined_permissions(self) -> None:
@@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of other, expect 403
@@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of self, expect 200
@@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
@@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# It is always valid to LEAVE if you've already left (currently.)
@@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
@@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.BAN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to invite: Must never happen.
@@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.INVITE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase):
src=other,
targ=other,
membership=Membership.JOIN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to knock: Must never happen.
@@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.KNOCK,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
+class RoomStateTestCase(RoomBase):
+ """Tests /rooms/$room_id/state."""
+
+ user_id = "@sid1:red"
+
+ def test_get_state_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertCountEqual(
+ [state_event["type"] for state_event in channel.json_list],
+ {
+ "m.room.create",
+ "m.room.power_levels",
+ "m.room.join_rules",
+ "m.room.member",
+ "m.room.history_visibility",
+ },
+ )
+
+ def test_get_state_event_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(channel.json_body, {"membership": "join"})
+
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
@@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase):
# first sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that permission is denied for @sid1:red to get the
@@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase):
"GET",
f"/rooms/{room_id}/members?at={sync_token}",
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self) -> None:
"""
@@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase):
# check that the user can see the member list to start with
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user
self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
# check the user can no longer see the member list
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
@@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase):
# sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that the user can see the member list to start with
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user (Note: the user is actually allowed to see this event and
# state so that they know they're banned!)
@@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase):
# now, with the original user, sync again to get a new at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check the user can no longer see the updated member list
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
@@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase):
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
channel = self.make_request("GET", room_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ def test_get_member_list_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
+ def test_get_member_list_with_at_token_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEqual(HTTPStatus.OK, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_with_at_token_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members?at=%s" % (room_id, sync_token),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
class RoomsCreateTestCase(RoomBase):
@@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase):
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(44, channel.resource_usage.db_txn_count)
+
+ def test_post_room_initial_state(self) -> None:
+ # POST with initial_state config key, expect new room id
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ b'{"initial_state":[{"type": "m.bridge", "content": {}}]}',
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(50, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self) -> None:
@@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
@@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self) -> None:
@@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase):
# Build the request's content. We use local MXIDs because invites over federation
# are more difficult to mock.
- content = json.dumps(
- {
- "invite": [
- "@alice1:red",
- "@alice2:red",
- "@alice3:red",
- "@alice4:red",
- ]
- }
- ).encode("utf8")
+ content = {
+ "invite": [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+ }
# Test that the invites are correctly ratelimited.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
self.assertEqual(
"Cannot invite so many users at once",
channel.json_body["error"],
@@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase):
# Test that the invites aren't ratelimited anymore.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
"""
async def user_may_join_room(
@@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
+
+ async def user_may_join_room_codes(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Codes:
+ return Codes.CONSENT_NOT_GIVEN
+
+ join_mock = Mock(side_effect=user_may_join_room_codes)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ self.assertEqual(join_mock.call_count, 0)
+
+ # Now change the return value of the callback to deny any join. Since we're
+ # creating the room, despite the return value, we should be able to join.
+ async def user_may_join_room_tuple(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Tuple[Codes, dict]:
+ return Codes.INCOMPATIBLE_ROOM_VERSION, {}
+
+ join_mock.side_effect = user_may_join_room_tuple
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid key, wrong type
content = '{"topic":["Topic name"]}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
- self.assertEqual(404, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.LEAVE,
)
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
@@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
@@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self) -> None:
@@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase):
"Join us!",
)
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
+
+ This test uses the deprecated API, in which callbacks return booleans.
"""
# Register a dummy callback. Make it allow all room joins for now.
@@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
@@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase):
# Now make the callback deny all room joins, and check that a join actually fails.
return_value = False
- self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+ self.helper.join(
+ self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2
+ )
+
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+
+ This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value: Union[
+ Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
+ ] = synapse.module_api.NOT_SPAM
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]:
+ return return_value
+
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`.
+ return_value = Codes.CONSENT_NOT_GIVEN
+ self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value,
+ tok=self.tok2,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # As above, with the experimental extension that lets us return dictionaries.
+ return_value = (Codes.BAD_ALIAS, {"another_field": "12345"})
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value[0],
+ tok=self.tok2,
+ expect_additional_fields=return_value[1],
+ )
class RoomJoinRatelimitTestCase(RoomBase):
@@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
@@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ @parameterized.expand(
+ [
+ # Allow
+ param(
+ name="NOT_SPAM",
+ value="NOT_SPAM",
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ param(
+ name="False",
+ value=False,
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ # Block
+ param(
+ name="scalene string",
+ value="ANY OTHER STRING",
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="True",
+ value=True,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="Code",
+ value=Codes.LIMIT_EXCEEDED,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_LIMIT_EXCEEDED"},
+ ),
+ param(
+ name="Tuple",
+ value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}),
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={
+ "errcode": "M_SERVER_NOT_TRUSTED",
+ "additional_field": "12345",
+ },
+ ),
+ ]
+ )
+ def test_spam_checker_check_event_for_spam(
+ self,
+ name: str,
+ value: Union[str, bool, Codes, Tuple[Codes, JsonDict]],
+ expected_code: int,
+ expected_fields: dict,
+ ) -> None:
+ class SpamCheck:
+ mock_return_value: Union[
+ str, bool, Codes, Tuple[Codes, JsonDict], bool
+ ] = "NOT_SPAM"
+ mock_content: Optional[JsonDict] = None
+
+ async def check_event_for_spam(
+ self,
+ event: synapse.events.EventBase,
+ ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]:
+ self.mock_content = event.content
+ return self.mock_return_value
+
+ spam_checker = SpamCheck()
+
+ self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
+ spam_checker.check_event_for_spam
+ )
+
+ # Inject `value` as mock_return_value
+ spam_checker.mock_return_value = value
+ path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % (
+ urlparse.quote(self.room_id),
+ urlparse.quote(name),
+ )
+ body = "test-%s" % name
+ content = '{"body":"%s","msgtype":"m.text"}' % body
+ channel = self.make_request("PUT", path, content)
+
+ # Check that the callback has witnessed the correct event.
+ self.assertIsNotNone(spam_checker.mock_content)
+ if (
+ spam_checker.mock_content is not None
+ ): # Checked just above, but mypy doesn't know about that.
+ self.assertEqual(
+ spam_checker.mock_content["body"], body, spam_checker.mock_content
+ )
+
+ # Check that we have the correct result.
+ self.assertEqual(expected_code, channel.code, msg=channel.result["body"])
+ for expected_key, expected_value in expected_fields.items():
+ self.assertEqual(
+ channel.json_body.get(expected_key, None),
+ expected_value,
+ "Field %s absent or invalid " % expected_key,
+ )
class RoomPowerLevelOverridesTestCase(RoomBase):
@@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_normal_user_can_not_post_state_event(self) -> None:
# Given I am a normal member of a room
@@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed because state events require PL>=50
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
"user_level (0) < send_level (50)",
@@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (1)",
@@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
# Then I am not allowed because the public_chat config does not
# affect this room, because this room is a private_chat
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
@@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase):
def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(self.room_id, channel.json_body["room_id"])
self.assertEqual("join", channel.json_body["membership"])
@@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
@@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
@@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
@@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
channel = self.make_request("GET", self.url, access_token=tok)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = b"/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ user = self.register_user("alice", "pass")
+ self.token = self.login(user, "pass")
+
+ # Create a room
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=self.token,
+ )
+ # Create a space
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={
+ "visibility": "public",
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ tok=self.token,
+ )
+
+ def make_public_rooms_request(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ channel = self.make_request(
+ "POST",
+ self.url,
+ {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+ self.token,
+ )
+ chunk = channel.json_body["chunk"]
+ count = channel.json_body["total_room_count_estimate"]
+
+ self.assertEqual(len(chunk), count)
+
+ return chunk, count
+
+ def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(None)
+
+ self.assertEqual(count, 2)
+
+ def test_returns_only_rooms_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request([None])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), None)
+
+ def test_returns_only_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space"])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), "m.space")
+
+ def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space", None])
+
+ self.assertEqual(count, 2)
+
+ def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+ chunk, count = self.make_public_rooms_request([])
+
+ self.assertEqual(count, 2)
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
@@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
@@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
- HttpResponseException(404, "Not Found", b""),
+ HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}),
)
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
@@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
# Set a profile for the test user
self.displayname = "test user"
- data = {"displayname": self.displayname}
- request_data = json.dumps(data)
+ request_data = {"displayname": self.displayname}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self) -> None:
- data = {"membership": "join", "displayname": "other test user"}
- request_data = json.dumps(data)
+ request_data = {"membership": "join", "displayname": "other test user"}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
@@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
@@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
),
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_content = channel.json_body
@@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["chunk"]
@@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
- json.dumps(content),
+ content,
access_token=self.room_owner_tok,
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_threepid_invite_spamcheck(self) -> None:
+ def test_threepid_invite_spamcheck_deprecated(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable(0))
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
@@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_threepid_invite_spamcheck(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(
+ return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+ spec=lambda *x: None,
+ )
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite. We pick an arbitrary error code to be able to check
+ # that the same code has been returned
+ mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ # Run variant with `Tuple[Codes, dict]`.
+ mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
+ self.assertEqual(channel.json_body["field"], "value")
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
|