summary refs log tree commit diff
path: root/tests/rest/client/test_rooms.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_rooms.py')
-rw-r--r--tests/rest/client/test_rooms.py936
1 files changed, 757 insertions, 179 deletions
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py

index f523d89b8f..c7eb88d33f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional +from http import HTTPStatus +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse +from parameterized import param, parameterized +from typing_extensions import Literal + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,7 +34,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -42,6 +48,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' @@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room: str, members: Iterable = frozenset(), expect_code: int = 200 @@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) def test_joined_permissions(self) -> None: @@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of other, expect 403 @@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of self, expect 200 @@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( @@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # It is always valid to LEAVE if you've already left (currently.) @@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember @@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.BAN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to invite: Must never happen. @@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.INVITE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase): src=other, targ=other, membership=Membership.JOIN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to knock: Must never happen. @@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) +class RoomStateTestCase(RoomBase): + """Tests /rooms/$room_id/state.""" + + user_id = "@sid1:red" + + def test_get_state_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertCountEqual( + [state_event["type"] for state_event in channel.json_list], + { + "m.room.create", + "m.room.power_levels", + "m.room.join_rules", + "m.room.member", + "m.room.history_visibility", + }, + ) + + def test_get_state_event_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(channel.json_body, {"membership": "join"}) + + class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" @@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + def test_get_member_list_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_get_member_list_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) + + def test_get_member_list_with_at_token_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request.""" + room_id = self.helper.create_room_as(self.user_id) + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEqual(HTTPStatus.OK, channel.code) + sync_token = channel.json_body["next_batch"] + + channel = make_request_with_cancellation_test( + "test_get_member_list_with_at_token_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) class RoomsCreateTestCase(RoomBase): @@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(44, channel.resource_usage.db_txn_count) + + def test_post_room_initial_state(self) -> None: + # POST with initial_state config key, expect new room id + channel = self.make_request( + "POST", + "/createRoom", + b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', + ) + + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. + + In this test, we use the deprecated API in which callbacks return a bool. """ async def user_may_join_room( @@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ + + async def user_may_join_room_codes( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Codes: + return Codes.CONSENT_NOT_GIVEN + + join_mock = Mock(side_effect=user_may_join_room_codes) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + self.assertEqual(join_mock.call_count, 0) + + # Now change the return value of the callback to deny any join. Since we're + # creating the room, despite the return value, we should be able to join. + async def user_may_join_room_tuple( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Tuple[Codes, dict]: + return Codes.INCOMPATIBLE_ROOM_VERSION, {} + + join_mock.side_effect = user_may_join_room_tuple + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. + + This test uses the deprecated API, in which callbacks return booleans. """ # Register a dummy callback. Make it allow all room joins for now. @@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) @@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) + + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + + This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value: Union[ + Literal["NOT_SPAM"], Tuple[Codes, dict], Codes + ] = synapse.module_api.NOT_SPAM + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]: + return return_value + + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`. + return_value = Codes.CONSENT_NOT_GIVEN + self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value, + tok=self.tok2, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # As above, with the experimental extension that lets us return dictionaries. + return_value = (Codes.BAD_ALIAS, {"another_field": "12345"}) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value[0], + tok=self.tok2, + expect_additional_fields=return_value[1], + ) class RoomJoinRatelimitTestCase(RoomBase): @@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + @parameterized.expand( + [ + # Allow + param( + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + # Block + param( + name="scalene string", + value="ANY OTHER STRING", + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="True", + value=True, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="Code", + value=Codes.LIMIT_EXCEEDED, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, + ), + param( + name="Tuple", + value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={ + "errcode": "M_SERVER_NOT_TRUSTED", + "additional_field": "12345", + }, + ), + ] + ) + def test_spam_checker_check_event_for_spam( + self, + name: str, + value: Union[str, bool, Codes, Tuple[Codes, JsonDict]], + expected_code: int, + expected_fields: dict, + ) -> None: + class SpamCheck: + mock_return_value: Union[ + str, bool, Codes, Tuple[Codes, JsonDict], bool + ] = "NOT_SPAM" + mock_content: Optional[JsonDict] = None + + async def check_event_for_spam( + self, + event: synapse.events.EventBase, + ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]: + self.mock_content = event.content + return self.mock_return_value + + spam_checker = SpamCheck() + + self.hs.get_spam_checker()._check_event_for_spam_callbacks.append( + spam_checker.check_event_for_spam + ) + + # Inject `value` as mock_return_value + spam_checker.mock_return_value = value + path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % ( + urlparse.quote(self.room_id), + urlparse.quote(name), + ) + body = "test-%s" % name + content = '{"body":"%s","msgtype":"m.text"}' % body + channel = self.make_request("PUT", path, content) + + # Check that the callback has witnessed the correct event. + self.assertIsNotNone(spam_checker.mock_content) + if ( + spam_checker.mock_content is not None + ): # Checked just above, but mypy doesn't know about that. + self.assertEqual( + spam_checker.mock_content["body"], body, spam_checker.mock_content + ) + + # Check that we have the correct result. + self.assertEqual(expected_code, channel.code, msg=channel.result["body"]) + for expected_key, expected_value in expected_fields.items(): + self.assertEqual( + channel.json_body.get(expected_key, None), + expected_value, + "Field %s absent or invalid " % expected_key, + ) class RoomPowerLevelOverridesTestCase(RoomBase): @@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", @@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): @@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self) -> None: + def test_threepid_invite_spamcheck_deprecated(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the deprecated API in which callbacks return a bool. + """ # Mock a few functions to prevent the test from failing due to failing to talk to - # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_threepid_invite_spamcheck(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock( + return_value=make_awaitable(synapse.module_api.NOT_SPAM), + spec=lambda *x: None, + ) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. We pick an arbitrary error code to be able to check + # that the same code has been returned + mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + # Run variant with `Tuple[Codes, dict]`. + mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT) + self.assertEqual(channel.json_body["field"], "value") + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + def test_400_missing_param_without_id_access_token(self) -> None: + """ + Test that a 3pid invite request returns 400 M_MISSING_PARAM + if we do not include id_access_token. + """ + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "medium": "email", + "address": "teresa@example.com", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")