summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py96
-rw-r--r--tests/rest/admin/test_federation.py75
-rw-r--r--tests/rest/client/test_relations.py26
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/test_federation.py72
-rw-r--r--tests/test_utils/__init__.py4
6 files changed, 241 insertions, 34 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee348..734ed84d78 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Iterable
 from unittest import mock
 
+from parameterized import parameterized
 from signedjson import key as key, sign as sign
 
 from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_user_id = "@test:other"
         local_user_id = "@test:test"
 
+        # Pretend we're sharing a room with the user we're querying. If not,
+        # `_query_devices_for_destination` will return early.
         self.store.get_rooms_for_user = mock.Mock(
             return_value=defer.succeed({"some_room_id"})
         )
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                 }
             },
         )
+
+    @parameterized.expand(
+        [
+            # The remote homeserver's response indicates that this user has 0/1/2 devices.
+            ([],),
+            (["device_1"],),
+            (["device_1", "device_2"],),
+        ]
+    )
+    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+        """Test that requests for all of a remote user's devices are cached.
+
+        We do this by asserting that only one call over federation was made, and that
+        the two queries to the local homeserver produce the same response.
+        """
+        local_user_id = "@test:test"
+        remote_user_id = "@test:other"
+        request_body = {"device_keys": {remote_user_id: []}}
+
+        response_devices = [
+            {
+                "device_id": device_id,
+                "keys": {
+                    "algorithms": ["dummy"],
+                    "device_id": device_id,
+                    "keys": {f"dummy:{device_id}": "dummy"},
+                    "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+                    "unsigned": {},
+                    "user_id": "@test:other",
+                },
+            }
+            for device_id in device_ids
+        ]
+
+        response_body = {
+            "devices": response_devices,
+            "user_id": remote_user_id,
+            "stream_id": 12345,  # an integer, according to the spec
+        }
+
+        e2e_handler = self.hs.get_e2e_keys_handler()
+
+        # Pretend we're sharing a room with the user we're querying. If not,
+        # `_query_devices_for_destination` will return early.
+        mock_get_rooms = mock.patch.object(
+            self.store,
+            "get_rooms_for_user",
+            new_callable=mock.MagicMock,
+            return_value=make_awaitable(["some_room_id"]),
+        )
+        mock_request = mock.patch.object(
+            self.hs.get_federation_client(),
+            "query_user_devices",
+            new_callable=mock.MagicMock,
+            return_value=make_awaitable(response_body),
+        )
+
+        with mock_get_rooms, mock_request as mocked_federation_request:
+            # Make the first query and sanity check it succeeds.
+            response_1 = self.get_success(
+                e2e_handler.query_devices(
+                    request_body,
+                    timeout=10,
+                    from_user_id=local_user_id,
+                    from_device_id="some_device_id",
+                )
+            )
+            self.assertEqual(response_1["failures"], {})
+
+            # We should have made a federation request to do so.
+            mocked_federation_request.assert_called_once()
+
+            # Reset the mock so we can prove we don't make a second federation request.
+            mocked_federation_request.reset_mock()
+
+            # Repeat the query.
+            response_2 = self.get_success(
+                e2e_handler.query_devices(
+                    request_body,
+                    timeout=10,
+                    from_user_id=local_user_id,
+                    from_device_id="some_device_id",
+                )
+            )
+            self.assertEqual(response_2["failures"], {})
+
+            # We should not have made a second federation request.
+            mocked_federation_request.assert_not_called()
+
+            # The two requests to the local homeserver should be identical.
+            self.assertEqual(response_1, response_2)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 742f194257..b70350b6f1 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
             retry_interval,
             last_successful_stream_ordering,
         ) in dest:
-            self.get_success(
-                self.store.set_destination_retry_timings(
-                    destination, failure_ts, retry_last_ts, retry_interval
-                )
-            )
-            self.get_success(
-                self.store.set_destination_last_successful_stream_ordering(
-                    destination, last_successful_stream_ordering
-                )
+            self._create_destination(
+                destination,
+                failure_ts,
+                retry_last_ts,
+                retry_interval,
+                last_successful_stream_ordering,
             )
 
         # order by default (destination)
@@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo")
         _search_test(None, "bar")
 
-    def test_get_single_destination(self) -> None:
-        """
-        Get one specific destinations.
-        """
-        self._create_destinations(5)
+    def test_get_single_destination_with_retry_timings(self) -> None:
+        """Get one specific destination which has retry timings."""
+        self._create_destinations(1)
 
         channel = self.make_request(
             "GET",
@@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # convert channel.json_body into a List
         self._check_fields([channel.json_body])
 
+    def test_get_single_destination_no_retry_timings(self) -> None:
+        """Get one specific destination which has no retry timings."""
+        self._create_destination("sub0.example.com")
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/sub0.example.com",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual("sub0.example.com", channel.json_body["destination"])
+        self.assertEqual(0, channel.json_body["retry_last_ts"])
+        self.assertEqual(0, channel.json_body["retry_interval"])
+        self.assertIsNone(channel.json_body["failure_ts"])
+        self.assertIsNone(channel.json_body["last_successful_stream_ordering"])
+
+    def _create_destination(
+        self,
+        destination: str,
+        failure_ts: Optional[int] = None,
+        retry_last_ts: int = 0,
+        retry_interval: int = 0,
+        last_successful_stream_ordering: Optional[int] = None,
+    ) -> None:
+        """Create one specific destination
+
+        Args:
+            destination: the destination we have successfully sent to
+            failure_ts: when the server started failing (ms since epoch)
+            retry_last_ts: time of last retry attempt in unix epoch ms
+            retry_interval: how long until next retry in ms
+            last_successful_stream_ordering: the stream_ordering of the most
+                recent successfully-sent PDU
+        """
+        self.get_success(
+            self.store.set_destination_retry_timings(
+                destination, failure_ts, retry_last_ts, retry_interval
+            )
+        )
+        if last_successful_stream_ordering is not None:
+            self.get_success(
+                self.store.set_destination_last_successful_stream_ordering(
+                    destination, last_successful_stream_ordering
+                )
+            )
+
     def _create_destinations(self, number_destinations: int) -> None:
         """Create a number of destinations
 
@@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         """
         for i in range(0, number_destinations):
             dest = f"sub{i}.example.com"
-            self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50))
-            self.get_success(
-                self.store.set_destination_last_successful_stream_ordering(dest, 100)
-            )
+            self._create_destination(dest, 50, 50, 50, 100)
 
     def _check_fields(self, content: List[JsonDict]) -> None:
         """Checks that the expected destination attributes are present in content
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c026d526ef..ff4e81d069 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -93,11 +93,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body,
         )
 
-    def test_deny_membership(self):
-        """Test that we deny relations on membership events"""
-        channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
-        self.assertEquals(400, channel.code, channel.json_body)
-
     def test_deny_invalid_event(self):
         """Test that we deny relations on non-existant events"""
         channel = self._send_relation(
@@ -1119,7 +1114,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             relation_type: One of `RelationTypes`
             event_type: The type of the event to create
             key: The aggregation key used for m.annotation relation type.
-            content: The content of the created event.
+            content: The content of the created event. Will be modified to configure
+                the m.relates_to key based on the other provided parameters.
             access_token: The access token used to send the relation, defaults
                 to `self.user_token`
             parent_id: The event_id this relation relates to. If None, then self.parent_id
@@ -1130,17 +1126,21 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         if not access_token:
             access_token = self.user_token
 
-        query = ""
-        if key:
-            query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
-
         original_id = parent_id if parent_id else self.parent_id
 
+        if content is None:
+            content = {}
+        content["m.relates_to"] = {
+            "event_id": original_id,
+            "rel_type": relation_type,
+        }
+        if key is not None:
+            content["m.relates_to"]["key"] = key
+
         channel = self.make_request(
             "POST",
-            "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
-            % (self.room, original_id, relation_type, event_type, query),
-            content or {},
+            f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
+            content,
             access_token=access_token,
         )
         return channel
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b58452195a..fe5b536d97 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         self.assertIsNotNone(event)
 
         time_now = self.clock.time_msec()
-        serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+        serialized = self.serializer.serialize_event(event, time_now)
 
         return serialized
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3eef1c4c05..2b9804aba0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -17,7 +17,9 @@ from unittest.mock import Mock
 from twisted.internet.defer import succeed
 
 from synapse.api.errors import FederationError
+from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
+from synapse.federation.federation_base import event_from_pdu_json
 from synapse.logging.context import LoggingContext
 from synapse.types import UserID, create_requester
 from synapse.util import Clock
@@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
         )
         self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
+
+
+class StripUnsignedFromEventsTestCase(unittest.TestCase):
+    def test_strip_unauthorized_unsigned_values(self):
+        event1 = {
+            "sender": "@baduser:test.serv",
+            "state_key": "@baduser:test.serv",
+            "event_id": "$event1:test.serv",
+            "depth": 1000,
+            "origin_server_ts": 1,
+            "type": "m.room.member",
+            "origin": "test.servx",
+            "content": {"membership": "join"},
+            "auth_events": [],
+            "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
+        }
+        filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
+        # Make sure unauthorized fields are stripped from unsigned
+        self.assertNotIn("more warez", filtered_event.unsigned)
+
+    def test_strip_event_maintains_allowed_fields(self):
+        event2 = {
+            "sender": "@baduser:test.serv",
+            "state_key": "@baduser:test.serv",
+            "event_id": "$event2:test.serv",
+            "depth": 1000,
+            "origin_server_ts": 1,
+            "type": "m.room.member",
+            "origin": "test.servx",
+            "auth_events": [],
+            "content": {"membership": "join"},
+            "unsigned": {
+                "malicious garbage": "hackz",
+                "more warez": "more hackz",
+                "age": 14,
+                "invite_room_state": [],
+            },
+        }
+
+        filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
+        self.assertIn("age", filtered_event2.unsigned)
+        self.assertEqual(14, filtered_event2.unsigned["age"])
+        self.assertNotIn("more warez", filtered_event2.unsigned)
+        # Invite_room_state is allowed in events of type m.room.member
+        self.assertIn("invite_room_state", filtered_event2.unsigned)
+        self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
+
+    def test_strip_event_removes_fields_based_on_event_type(self):
+        event3 = {
+            "sender": "@baduser:test.serv",
+            "state_key": "@baduser:test.serv",
+            "event_id": "$event3:test.serv",
+            "depth": 1000,
+            "origin_server_ts": 1,
+            "type": "m.room.power_levels",
+            "origin": "test.servx",
+            "content": {},
+            "auth_events": [],
+            "unsigned": {
+                "malicious garbage": "hackz",
+                "more warez": "more hackz",
+                "age": 14,
+                "invite_room_state": [],
+            },
+        }
+        filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
+        self.assertIn("age", filtered_event3.unsigned)
+        # Invite_room_state field is only permitted in event type m.room.member
+        self.assertNotIn("invite_room_state", filtered_event3.unsigned)
+        self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 15ac2bfeba..f05a373aa0 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -19,7 +19,7 @@ import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Any, Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, TypeVar
 from unittest.mock import Mock
 
 import attr
@@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
     raise Exception("awaitable has not yet completed")
 
 
-def make_awaitable(result: Any) -> Awaitable[Any]:
+def make_awaitable(result: TV) -> Awaitable[TV]:
     """
     Makes an awaitable, suitable for mocking an `async` function.
     This uses Futures as they can be awaited multiple times so can be returned