summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_account.py11
-rw-r--r--tests/rest/client/test_account_data.py75
-rw-r--r--tests/rest/client/test_relations.py207
-rw-r--r--tests/rest/client/test_rooms.py2
-rw-r--r--tests/rest/client/test_sync.py68
-rw-r--r--tests/rest/client/test_third_party_rules.py41
-rw-r--r--tests/rest/client/utils.py11
7 files changed, 202 insertions, 213 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py

index 27946febff..e00b5c171c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -89,6 +89,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) + def attempt_wrong_password_login(self, username: str, password: str) -> None: + """Attempts to login as the user with the given password, asserting + that the attempt *fails*. + """ + body = {"type": "m.login.password", "user": username, "password": password} + + channel = self.make_request( + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + ) + self.assertEqual(channel.code, 403, channel.result) + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py new file mode 100644
index 0000000000..d5b0640e7a --- /dev/null +++ b/tests/rest/client/test_account_data.py
@@ -0,0 +1,75 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from synapse.rest import admin +from synapse.rest.client import account_data, login, room + +from tests import unittest +from tests.test_utils import make_awaitable + + +class AccountDataTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + account_data.register_servlets, + ] + + def test_on_account_data_updated_callback(self) -> None: + """Tests that the on_account_data_updated module callback is called correctly when + a user's account data changes. + """ + mocked_callback = Mock(return_value=make_awaitable(None)) + self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( + mocked_callback + ) + + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + account_data_type = "org.matrix.foo" + account_data_content = {"bar": "baz"} + + # Change the user's global account data. + channel = self.make_request( + "PUT", + f"/user/{user_id}/account_data/{account_data_type}", + account_data_content, + access_token=tok, + ) + + # Test that the callback is called with the user ID, the new account data, and + # None as the room ID. + self.assertEqual(channel.code, 200, channel.result) + mocked_callback.assert_called_once_with( + user_id, None, account_data_type, account_data_content + ) + + # Change the user's room-specific account data. + room_id = self.helper.create_room_as(user_id, tok=tok) + channel = self.make_request( + "PUT", + f"/user/{user_id}/rooms/{room_id}/account_data/{account_data_type}", + account_data_content, + access_token=tok, + ) + + # Test that the callback is called with the user ID, the room ID and the new + # account data. + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(mocked_callback.call_count, 2) + mocked_callback.assert_called_with( + user_id, room_id, account_data_type, account_data_content + ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fe97a0b3dd..419eef166a 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch @@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. @@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase): expected_response_code=400, ) - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - # And when fetching aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - # And for bundled aggregations. channel = self.make_request( "GET", @@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - class BundledAggregationsTestCase(BaseRelationsTestCase): """ @@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, ) - # Both relations appear in the aggregation. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) - # Redact one of the reactions. self._redact(to_redact_event_id) @@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - # The unredacted aggregation should still exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) - # The aggregation should exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) - # Redact the original event. self._redact(self.parent_id) @@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - # There's nothing to aggregate. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 3a9617d6da..6ff79b9e2e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase): super().prepare(reactor, clock, hs) # profile changes expect that the user is actually registered user = UserID.from_string(self.user_id) - self.get_success(self.register_user(user.localpart, "supersecretpassword")) + self.register_user(user.localpart, "supersecretpassword") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 4351013952..773c16a54c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py
@@ -341,7 +341,6 @@ class SyncKnockTestCase( hs, self.room_id, self.user_id ) - @override_config({"experimental_features": {"msc2403_enabled": True}}) def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room @@ -497,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc2654_enabled": True} + return config + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -772,3 +776,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): self.assertIn( self.user_id, device_list_changes, incremental_sync_channel.json_body ) + + +class ExcludeRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # We need to manually append the room ID, because we can't know the ID before + # creating the room, and we can't set the config after starting the homeserver. + self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + + def test_join_leave(self) -> None: + """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of + sync responses. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) + + self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok) + self.helper.leave(self.included_room_id, self.user_id, tok=self.tok) + + channel = self.make_request( + "GET", + "/sync?since=" + channel.json_body["next_batch"], + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"]) + + def test_invite(self) -> None: + """Tests that rooms are correctly excluded from the 'invite' section of sync + responses. + """ + invitee = self.register_user("invitee", "password") + invitee_tok = self.login("invitee", "password") + + self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok) + self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok) + + channel = self.make_request("GET", "/sync", access_token=invitee_tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e7de67e3a3..5eb0f243f7 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Check that the mock was called with the right room ID self.assertEqual(args[1], self.room_id) + + def test_on_threepid_bind(self) -> None: + """Tests that the on_threepid_bind module callback is called correctly after + associating a 3PID to an account. + """ + # Register a mocked callback. + threepid_bind_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + threepid_bind_mock.assert_called_once() + args = threepid_bind_mock.call_args[0] + + # Check that the mock was called with the right parameters + self.assertEqual(args, (user_id, "email", "foo@example.com")) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 28663826fc..a0788b1bb0 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -88,7 +88,7 @@ class RestHelper: def create_room_as( self, room_creator: Optional[str] = None, - is_public: Optional[bool] = None, + is_public: Optional[bool] = True, room_version: Optional[str] = None, tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, @@ -101,9 +101,12 @@ class RestHelper: Args: room_creator: The user ID to create the room with. is_public: If True, the `visibility` parameter will be set to - "public". If False, it will be set to "private". If left - unspecified, the server will set it to an appropriate default - (which should be "private" as per the CS spec). + "public". If False, it will be set to "private". + If None, doesn't specify the `visibility` parameter in which + case the server is supposed to make the room private according to + the CS API. + Defaults to public, since that is commonly needed in tests + for convenience where room privacy is not a problem. room_version: The room version to create the room as. Defaults to Synapse's default room version. tok: The access token to use in the request.