diff options
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/test_rooms.py | 69 |
1 files changed, 66 insertions, 3 deletions
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b796163dcb..d398cead1c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -48,7 +48,16 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, register, room, sync +from synapse.rest.client import ( + account, + directory, + knock, + login, + profile, + register, + room, + sync, +) from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -733,7 +742,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(32, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -746,7 +755,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1154,6 +1163,7 @@ class RoomJoinTestCase(RoomBase): admin.register_servlets, login.register_servlets, room.register_servlets, + knock.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -1167,6 +1177,8 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.store = hs.get_datastores().main + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. @@ -1317,6 +1329,57 @@ class RoomJoinTestCase(RoomBase): expect_additional_fields=return_value[1], ) + def test_suspended_user_cannot_join_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", f"/join/{self.room1}", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + channel = self.make_request( + "POST", f"/rooms/{self.room1}/join", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_knock_on_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/knock/{self.room1}", + access_token=self.tok2, + content={}, + shorthand=False, + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_invite_to_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + # first user invites second user + channel = self.make_request( + "POST", + f"/rooms/{self.room1}/invite", + access_token=self.tok1, + content={"user_id": self.user2}, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): servlets = [ |