summary refs log tree commit diff
path: root/tests/rest/client/test_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_sync.py')
-rw-r--r--tests/rest/client/test_sync.py399
1 files changed, 389 insertions, 10 deletions
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
 import json
 from typing import List
 
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.json_body)
 
 
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
 class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device list (`device_lists`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
+        room.register_servlets,
         sync.register_servlets,
         devices.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def test_receiving_local_device_list_changes(self) -> None:
+        """Tests that a local users that share a room receive each other's device list
+        changes.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # Create a room for them to coexist peacefully in
+        new_room_id = self.helper.create_room_as(
+            alice_user_id, is_public=True, tok=alice_access_token
+        )
+        self.assertIsNotNone(new_room_id)
+
+        # Have Bob join the room
+        self.helper.invite(
+            new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+        )
+        self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+        # Now have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            self.sync_endpoint,
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync contains the updated device list.
+        # If not, the client would only receive the device list update on the
+        # *next* sync.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+    def test_not_receiving_local_device_list_changes(self) -> None:
+        """Tests a local users DO NOT receive device updates from each other if they do not
+        share a room.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # These users do not share a room. They are lonely.
+
+        # Have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            self.sync_endpoint,
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync does not contain the updated device list.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertNotIn(
+            alice_user_id, changed_device_lists, bob_sync_channel.json_body
+        )
+
     def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
         """Tests that a user with no rooms still receives their own device list updates"""
-        device_id = "TESTDEVICE"
+        test_device_id = "TESTDEVICE"
 
         # Register a user and login, creating a device
-        self.user_id = self.register_user("kermit", "monkey")
-        self.tok = self.login("kermit", "monkey", device_id=device_id)
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
 
         # Request an initial sync
-        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
         self.assertEqual(channel.code, 200, channel.json_body)
         next_batch = channel.json_body["next_batch"]
 
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         # It won't return until something has happened
         incremental_sync_channel = self.make_request(
             "GET",
-            f"/sync?since={next_batch}&timeout=30000",
-            access_token=self.tok,
+            f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+            access_token=alice_access_token,
             await_result=False,
         )
 
         # Change our device's display name
         channel = self.make_request(
             "PUT",
-            f"devices/{device_id}",
+            f"devices/{test_device_id}",
             {
                 "display_name": "freeze ray",
             },
-            access_token=self.tok,
+            access_token=alice_access_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         ).get("changed", [])
 
         self.assertIn(
-            self.user_id, device_list_changes, incremental_sync_channel.json_body
+            alice_user_id, device_list_changes, incremental_sync_channel.json_body
+        )
+
+
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    def test_no_device_one_time_keys(self) -> None:
+        """
+        Tests when no one time keys set, it still has the default `signed_curve25519` in
+        `device_one_time_keys_count`
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertDictEqual(
+            channel.json_body["device_one_time_keys_count"],
+            # Note that "signed_curve25519" is always returned in key count responses
+            # regardless of whether we uploaded any keys for it. This is necessary until
+            # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+            {"signed_curve25519": 0},
+            channel.json_body["device_one_time_keys_count"],
+        )
+
+    def test_returns_device_one_time_keys(self) -> None:
+        """
+        Tests that one time keys for the device/user are counted correctly in the `/sync`
+        response
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Upload one time keys for the user/device
+        keys: JsonDict = {
+            "alg1:k1": "key1",
+            "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+            "alg2:k3": {"key": "key3"},
+        }
+        res = self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                alice_user_id, test_device_id, {"one_time_keys": keys}
+            )
+        )
+        # Note that "signed_curve25519" is always returned in key count responses
+        # regardless of whether we uploaded any keys for it. This is necessary until
+        # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+        self.assertDictEqual(
+            res,
+            {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertDictEqual(
+            channel.json_body["device_one_time_keys_count"],
+            {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+            channel.json_body["device_one_time_keys_count"],
+        )
+
+
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = self.hs.get_datastores().main
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    def test_no_device_unused_fallback_key(self) -> None:
+        """
+        Test when no unused fallback key is set, it just returns an empty list. The MSC
+        says "The device_unused_fallback_key_types parameter must be present if the
+        server supports fallback keys.",
+        https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertListEqual(
+            channel.json_body["device_unused_fallback_key_types"],
+            [],
+            channel.json_body["device_unused_fallback_key_types"],
+        )
+
+    def test_returns_device_one_time_keys(self) -> None:
+        """
+        Tests that device unused fallback key type is returned correctly in the `/sync`
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # We shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a fallback key for the user/device
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                alice_user_id,
+                test_device_id,
+                {"fallback_keys": fallback_key},
+            )
+        )
+        # We should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for the unused fallback key types
+        self.assertListEqual(
+            channel.json_body["device_unused_fallback_key_types"],
+            ["alg1"],
+            channel.json_body["device_unused_fallback_key_types"],
         )