diff options
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/test_capabilities.py | 69 | ||||
-rw-r--r-- | tests/rest/client/test_register.py | 2 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 227 | ||||
-rw-r--r-- | tests/rest/client/test_room_batch.py | 2 | ||||
-rw-r--r-- | tests/rest/client/test_rooms.py | 119 | ||||
-rw-r--r-- | tests/rest/client/test_sendtodevice.py | 40 | ||||
-rw-r--r-- | tests/rest/client/test_sync.py | 57 | ||||
-rw-r--r-- | tests/rest/client/test_third_party_rules.py | 6 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 31 |
9 files changed, 355 insertions, 198 deletions
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 249808b031..989e801768 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login @@ -28,7 +30,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.url = b"/_matrix/client/r0/capabilities" + self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() @@ -96,39 +98,20 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): - """Test that per default msc3283 is disabled server returns `m.change_password`.""" + def test_get_change_users_attributes_capabilities(self): + """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) - self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) - self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) - - @override_config({"experimental_features": {"msc3283_enabled": True}}) - def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): - """Test if msc3283 is enabled server returns capabilities.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] + self.assertTrue(capabilities["m.set_displayname"]["enabled"]) + self.assertTrue(capabilities["m.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - @override_config( - { - "enable_set_displayname": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) + @override_config({"enable_set_displayname": False}) def test_get_set_displayname_capabilities_displayname_disabled(self): """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -136,15 +119,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - - @override_config( - { - "enable_set_avatar_url": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.set_displayname"]["enabled"]) + + @override_config({"enable_set_avatar_url": False}) def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -152,24 +130,19 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - - @override_config( - { - "enable_3pid_changes": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_change_3pid_capabilities_3pid_disabled(self): + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) + + @override_config({"enable_3pid_changes": False}) + def test_get_change_3pid_capabilities_3pid_disabled(self): """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 407dd32a73..0f1c47dcbb 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -1154,7 +1154,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] - url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + url = "/_matrix/client/v1/register/m.login.registration_token/validity" def default_config(self): config = super().default_config() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 96ae7790bb..de80aca037 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,7 +21,8 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync -from synapse.types import JsonDict +from synapse.storage.relations import RelationPaginationToken +from synapse.types import JsonDict, StreamToken from tests import unittest from tests.server import FakeChannel @@ -168,24 +169,28 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) + first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] + second_annotation_id = channel.json_body["event_id"] channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the full - # relation event we sent above. + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( - {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, channel.json_body["chunk"][0], ) @@ -200,6 +205,36 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body.get("next_batch"), str, channel.json_body ) + # Request the relations again, but with a different direction. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + def _stream_token_to_relation_token(self, token: str) -> str: + """Convert a StreamToken into a legacy token (RelationPaginationToken).""" + room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key + return self.get_success( + RelationPaginationToken( + topological=room_key.topological, stream=room_key.stream + ).to_string(self.store) + ) + def test_repeated_paginate_relations(self): """Test that if we paginate using a limit and tokens then we get the expected events. @@ -213,7 +248,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) - prev_token: Optional[str] = None + prev_token = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" @@ -222,8 +257,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -241,6 +275,93 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEquals(found_event_ids, expected_event_ids) + # Reset and try again, but convert the tokens to the legacy format. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + self._stream_token_to_relation_token(prev_token) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self): + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 1}}}'.encode() + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) + + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b&limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + def test_aggregation_pagination_groups(self): """Test that we can paginate annotation groups correctly.""" @@ -337,7 +458,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) - prev_token: Optional[str] = None + prev_token = "" found_event_ids: List[str] = [] encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): @@ -347,15 +468,42 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s" - "/aggregations/%s/%s/m.reaction/%s?limit=1%s" - % ( - self.room, - self.parent_id, - RelationTypes.ANNOTATION, - encoded_key, - from_token, - ), + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + # Reset and try again, but convert the tokens to the legacy format. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + self._stream_token_to_relation_token(prev_token) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -453,7 +601,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + @unittest.override_config( + {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} + ) def test_bundled_aggregations(self): """ Test that annotations, references, and threads get correctly bundled. @@ -579,6 +729,23 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included when directly requested. @@ -759,6 +926,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_edit(self): """Test that a simple edit works.""" @@ -825,6 +993,23 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + def test_multi_edit(self): """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 721454c187..e9f8704035 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -89,7 +89,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.clock = clock self.storage = hs.get_storage() - self.virtual_user_id = self.register_appservice_user( + self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 10a4a4dc5e..b7f086927b 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Dict, Iterable, List, Optional +from typing import Iterable, List from unittest.mock import Mock, call from urllib import parse as urlparse @@ -35,7 +35,7 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync -from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -674,121 +674,6 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spamchecker_invites(self): - """Tests the user_may_create_room_with_invites spam checker callback.""" - - # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an - # IS. - async def do_3pid_invite( - room_id: str, - inviter: UserID, - medium: str, - address: str, - id_server: str, - requester: Requester, - txn_id: Optional[str], - id_access_token: Optional[str] = None, - ) -> int: - return 0 - - do_3pid_invite_mock = Mock(side_effect=do_3pid_invite) - self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock - - # Add a mock callback for user_may_create_room_with_invites. Make it allow any - # room creation request for now. - return_value = True - - async def user_may_create_room_with_invites( - user: str, - invites: List[str], - threepid_invites: List[Dict[str, str]], - ) -> bool: - return return_value - - callback_mock = Mock(side_effect=user_may_create_room_with_invites) - self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append( - callback_mock, - ) - - # The MXIDs we'll try to invite. - invited_mxids = [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - - # The 3PIDs we'll try to invite. - invited_3pids = [ - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice1@example.com", - }, - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice2@example.com", - }, - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice3@example.com", - }, - ] - - # Create a room and invite the Matrix users, and check that it succeeded. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite": invited_mxids}).encode("utf8"), - ) - self.assertEqual(200, channel.code) - - # Check that the callback was called with the right arguments. - expected_call_args = ((self.user_id, invited_mxids, []),) - self.assertEquals( - callback_mock.call_args, - expected_call_args, - callback_mock.call_args, - ) - - # Create a room and invite the 3PIDs, and check that it succeeded. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), - ) - self.assertEqual(200, channel.code) - - # Check that do_3pid_invite was called the right amount of time - self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) - - # Check that the callback was called with the right arguments. - expected_call_args = ((self.user_id, [], invited_3pids),) - self.assertEquals( - callback_mock.call_args, - expected_call_args, - callback_mock.call_args, - ) - - # Now deny any room creation. - return_value = False - - # Create a room and invite the 3PIDs, and check that it failed. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), - ) - self.assertEqual(403, channel.code) - - # Check that do_3pid_invite wasn't called this time. - self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) - def test_spam_checker_may_join_room(self): """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index 6db7062a8e..e2ed14457f 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -198,3 +198,43 @@ class SendToDeviceTestCase(HomeserverTestCase): "content": {"idx": 3}, }, ) + + def test_limited_sync(self): + """If a limited sync for to-devices happens the next /sync should respond immediately.""" + + self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # Do an initial sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + sync_token = channel.json_body["next_batch"] + + # Send 150 to-device messages. We limit to 100 in `/sync` + for i in range(150): + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 100) + sync_token = channel.json_body["next_batch"] + + channel = self.make_request( + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 50) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index c427686376..cd4af2b1f3 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -23,7 +23,7 @@ from synapse.api.constants import ( ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client import knock, login, read_marker, receipts, room, sync +from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( @@ -710,3 +710,58 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): channel.await_result(timeout_ms=9900) channel.await_result(timeout_ms=200) self.assertEqual(channel.code, 200, channel.json_body) + + +class DeviceListSyncTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_user_with_no_rooms_receives_self_device_list_updates(self): + """Tests that a user with no rooms still receives their own device list updates""" + device_id = "TESTDEVICE" + + # Register a user and login, creating a device + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey", device_id=device_id) + + # Request an initial sync + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch = channel.json_body["next_batch"] + + # Now, make an incremental sync request. + # It won't return until something has happened + incremental_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=30000", + access_token=self.tok, + await_result=False, + ) + + # Change our device's display name + channel = self.make_request( + "PUT", + f"devices/{device_id}", + { + "display_name": "freeze ray", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # The sync should now have returned + incremental_sync_channel.await_result(timeout_ms=20000) + self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) + + # We should have received notification that the (user's) device has changed + device_list_changes = incremental_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) + + self.assertIn( + self.user_id, device_list_changes, incremental_sync_channel.json_body + ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 4e71b6ec12..ac6b86ff6b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): return hs def prepare(self, reactor, clock, homeserver): + super().prepare(reactor, clock, homeserver) # Create some users and a room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.invitee = self.register_user("invitee", "hackme") @@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def _send_event_over_federation(self) -> None: """Send a dummy event over federation and check that the request succeeds.""" body = { - "origin": self.hs.config.server.server_name, - "origin_server_ts": self.clock.time_msec(), "pdus": [ { "sender": self.user_id, @@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ], } - channel = self.make_request( + channel = self.make_signed_federation_request( method="PUT", path="/_matrix/federation/v1/send/1", content=body, - federation_auth_origin=self.hs.config.server.server_name.encode("utf8"), ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 8424383580..1c0cb0cf4f 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,6 +31,7 @@ from typing import ( overload, ) from unittest.mock import patch +from urllib.parse import urlencode import attr from typing_extensions import Literal @@ -147,12 +148,20 @@ class RestHelper: expect_code=expect_code, ) - def join(self, room=None, user=None, expect_code=200, tok=None): + def join( + self, + room: str, + user: Optional[str] = None, + expect_code: int = 200, + tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, targ=user, tok=tok, + appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, ) @@ -209,11 +218,12 @@ class RestHelper: def change_membership( self, room: str, - src: str, - targ: str, + src: Optional[str], + targ: Optional[str], membership: str, extra_data: Optional[dict] = None, tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, expect_code: int = 200, expect_errcode: Optional[str] = None, ) -> None: @@ -227,15 +237,26 @@ class RestHelper: membership: The type of membership event extra_data: Extra information to include in the content of the event tok: The user access token to use + appservice_user_id: The `user_id` URL parameter to pass. + This allows driving an application service user + using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code """ temp_id = self.auth_user_id self.auth_user_id = src - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + path = f"/_matrix/client/r0/rooms/{room}/state/m.room.member/{targ}" + url_params: Dict[str, str] = {} + if tok: - path = path + "?access_token=%s" % tok + url_params["access_token"] = tok + + if appservice_user_id: + url_params["user_id"] = appservice_user_id + + if url_params: + path += "?" + urlencode(url_params) data = {"membership": membership} data.update(extra_data or {}) |