diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
import json
from typing import List
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device list (`device_lists`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def test_receiving_local_device_list_changes(self) -> None:
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self) -> None:
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
+
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
- device_id = "TESTDEVICE"
+ test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey", device_id=device_id)
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
# Request an initial sync
- channel = self.make_request("GET", "/sync", access_token=self.tok)
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
- f"/sync?since={next_batch}&timeout=30000",
- access_token=self.tok,
+ f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+ access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
- f"devices/{device_id}",
+ f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
- access_token=self.tok,
+ access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
- self.user_id, device_list_changes, incremental_sync_channel.json_body
+ alice_user_id, device_list_changes, incremental_sync_channel.json_body
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_one_time_keys(self) -> None:
+ """
+ Tests when no one time keys set, it still has the default `signed_curve25519` in
+ `device_one_time_keys_count`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ {"signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that one time keys for the device/user are counted correctly in the `/sync`
+ response
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Upload one time keys for the user/device
+ keys: JsonDict = {
+ "alg1:k1": "key1",
+ "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+ "alg2:k3": {"key": "key3"},
+ }
+ res = self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id, test_device_id, {"one_time_keys": keys}
+ )
+ )
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ self.assertDictEqual(
+ res,
+ {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = self.hs.get_datastores().main
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_unused_fallback_key(self) -> None:
+ """
+ Test when no unused fallback key is set, it just returns an empty list. The MSC
+ says "The device_unused_fallback_key_types parameter must be present if the
+ server supports fallback keys.",
+ https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ [],
+ channel.json_body["device_unused_fallback_key_types"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that device unused fallback key type is returned correctly in the `/sync`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # We shouldn't have any unused fallback keys yet
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(res, [])
+
+ # Upload a fallback key for the user/device
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id,
+ test_device_id,
+ {"fallback_keys": fallback_key},
+ )
+ )
+ # We should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for the unused fallback key types
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ ["alg1"],
+ channel.json_body["device_unused_fallback_key_types"],
)
|