diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 96 | ||||
-rw-r--r-- | tests/rest/admin/test_federation.py | 75 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 26 | ||||
-rw-r--r-- | tests/rest/client/test_retention.py | 2 | ||||
-rw-r--r-- | tests/test_federation.py | 72 | ||||
-rw-r--r-- | tests/test_utils/__init__.py | 4 |
6 files changed, 241 insertions, 34 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ddcf3ee348..734ed84d78 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable from unittest import mock +from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer @@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError from tests import unittest +from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): @@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_user_id = "@test:other" local_user_id = "@test:test" + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( return_value=defer.succeed({"some_room_id"}) ) @@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } }, ) + + @parameterized.expand( + [ + # The remote homeserver's response indicates that this user has 0/1/2 devices. + ([],), + (["device_1"],), + (["device_1", "device_2"],), + ] + ) + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + """Test that requests for all of a remote user's devices are cached. + + We do this by asserting that only one call over federation was made, and that + the two queries to the local homeserver produce the same response. + """ + local_user_id = "@test:test" + remote_user_id = "@test:other" + request_body = {"device_keys": {remote_user_id: []}} + + response_devices = [ + { + "device_id": device_id, + "keys": { + "algorithms": ["dummy"], + "device_id": device_id, + "keys": {f"dummy:{device_id}": "dummy"}, + "signatures": {device_id: {f"dummy:{device_id}": "dummy"}}, + "unsigned": {}, + "user_id": "@test:other", + }, + } + for device_id in device_ids + ] + + response_body = { + "devices": response_devices, + "user_id": remote_user_id, + "stream_id": 12345, # an integer, according to the spec + } + + e2e_handler = self.hs.get_e2e_keys_handler() + + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. + mock_get_rooms = mock.patch.object( + self.store, + "get_rooms_for_user", + new_callable=mock.MagicMock, + return_value=make_awaitable(["some_room_id"]), + ) + mock_request = mock.patch.object( + self.hs.get_federation_client(), + "query_user_devices", + new_callable=mock.MagicMock, + return_value=make_awaitable(response_body), + ) + + with mock_get_rooms, mock_request as mocked_federation_request: + # Make the first query and sanity check it succeeds. + response_1 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_1["failures"], {}) + + # We should have made a federation request to do so. + mocked_federation_request.assert_called_once() + + # Reset the mock so we can prove we don't make a second federation request. + mocked_federation_request.reset_mock() + + # Repeat the query. + response_2 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_2["failures"], {}) + + # We should not have made a second federation request. + mocked_federation_request.assert_not_called() + + # The two requests to the local homeserver should be identical. + self.assertEqual(response_1, response_2) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 742f194257..b70350b6f1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase): retry_interval, last_successful_stream_ordering, ) in dest: - self.get_success( - self.store.set_destination_retry_timings( - destination, failure_ts, retry_last_ts, retry_interval - ) - ) - self.get_success( - self.store.set_destination_last_successful_stream_ordering( - destination, last_successful_stream_ordering - ) + self._create_destination( + destination, + failure_ts, + retry_last_ts, + retry_interval, + last_successful_stream_ordering, ) # order by default (destination) @@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - def test_get_single_destination(self) -> None: - """ - Get one specific destinations. - """ - self._create_destinations(5) + def test_get_single_destination_with_retry_timings(self) -> None: + """Get one specific destination which has retry timings.""" + self._create_destinations(1) channel = self.make_request( "GET", @@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase): # convert channel.json_body into a List self._check_fields([channel.json_body]) + def test_get_single_destination_no_retry_timings(self) -> None: + """Get one specific destination which has no retry timings.""" + self._create_destination("sub0.example.com") + + channel = self.make_request( + "GET", + self.url + "/sub0.example.com", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("sub0.example.com", channel.json_body["destination"]) + self.assertEqual(0, channel.json_body["retry_last_ts"]) + self.assertEqual(0, channel.json_body["retry_interval"]) + self.assertIsNone(channel.json_body["failure_ts"]) + self.assertIsNone(channel.json_body["last_successful_stream_ordering"]) + + def _create_destination( + self, + destination: str, + failure_ts: Optional[int] = None, + retry_last_ts: int = 0, + retry_interval: int = 0, + last_successful_stream_ordering: Optional[int] = None, + ) -> None: + """Create one specific destination + + Args: + destination: the destination we have successfully sent to + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms + last_successful_stream_ordering: the stream_ordering of the most + recent successfully-sent PDU + """ + self.get_success( + self.store.set_destination_retry_timings( + destination, failure_ts, retry_last_ts, retry_interval + ) + ) + if last_successful_stream_ordering is not None: + self.get_success( + self.store.set_destination_last_successful_stream_ordering( + destination, last_successful_stream_ordering + ) + ) + def _create_destinations(self, number_destinations: int) -> None: """Create a number of destinations @@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase): """ for i in range(0, number_destinations): dest = f"sub{i}.example.com" - self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) - self.get_success( - self.store.set_destination_last_successful_stream_ordering(dest, 100) - ) + self._create_destination(dest, 50, 50, 50, 100) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected destination attributes are present in content diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c026d526ef..ff4e81d069 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -93,11 +93,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - def test_deny_invalid_event(self): """Test that we deny relations on non-existant events""" channel = self._send_relation( @@ -1119,7 +1114,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): relation_type: One of `RelationTypes` event_type: The type of the event to create key: The aggregation key used for m.annotation relation type. - content: The content of the created event. + content: The content of the created event. Will be modified to configure + the m.relates_to key based on the other provided parameters. access_token: The access token used to send the relation, defaults to `self.user_token` parent_id: The event_id this relation relates to. If None, then self.parent_id @@ -1130,17 +1126,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): if not access_token: access_token = self.user_token - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - original_id = parent_id if parent_id else self.parent_id + if content is None: + content = {} + content["m.relates_to"] = { + "event_id": original_id, + "rel_type": relation_type, + } + if key is not None: + content["m.relates_to"]["key"] = key + channel = self.make_request( "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - content or {}, + f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", + content, access_token=access_token, ) return channel diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index b58452195a..fe5b536d97 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(event) time_now = self.clock.time_msec() - serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + serialized = self.serializer.serialize_event(event, time_now) return serialized diff --git a/tests/test_federation.py b/tests/test_federation.py index 3eef1c4c05..2b9804aba0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -17,7 +17,9 @@ from unittest.mock import Mock from twisted.internet.defer import succeed from synapse.api.errors import FederationError +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext from synapse.types import UserID, create_requester from synapse.util import Clock @@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase): "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), ) self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) + + +class StripUnsignedFromEventsTestCase(unittest.TestCase): + def test_strip_unauthorized_unsigned_values(self): + event1 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event1:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"}, + } + filtered_event = event_from_pdu_json(event1, RoomVersions.V1) + # Make sure unauthorized fields are stripped from unsigned + self.assertNotIn("more warez", filtered_event.unsigned) + + def test_strip_event_maintains_allowed_fields(self): + event2 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event2:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "auth_events": [], + "content": {"membership": "join"}, + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + + filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) + self.assertIn("age", filtered_event2.unsigned) + self.assertEqual(14, filtered_event2.unsigned["age"]) + self.assertNotIn("more warez", filtered_event2.unsigned) + # Invite_room_state is allowed in events of type m.room.member + self.assertIn("invite_room_state", filtered_event2.unsigned) + self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) + + def test_strip_event_removes_fields_based_on_event_type(self): + event3 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event3:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.power_levels", + "origin": "test.servx", + "content": {}, + "auth_events": [], + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) + self.assertIn("age", filtered_event3.unsigned) + # Invite_room_state field is only permitted in event type m.room.member + self.assertNotIn("invite_room_state", filtered_event3.unsigned) + self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 15ac2bfeba..f05a373aa0 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -19,7 +19,7 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, TypeVar from unittest.mock import Mock import attr @@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def make_awaitable(result: Any) -> Awaitable[Any]: +def make_awaitable(result: TV) -> Awaitable[TV]: """ Makes an awaitable, suitable for mocking an `async` function. This uses Futures as they can be awaited multiple times so can be returned |