diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 2b360732ac..a3ed12a38f 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
-from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, keys, login, register
+from synapse.rest import admin, devices, sync
+from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
from tests import unittest
-class DeviceListsTestCase(unittest.HomeserverTestCase):
- """Tests regarding device list changes."""
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- register.register_servlets,
- account.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def test_receiving_local_device_list_changes(self) -> None:
- """Tests that a local users that share a room receive each other's device list
- changes.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # Create a room for them to coexist peacefully in
- new_room_id = self.helper.create_room_as(
- alice_user_id, is_public=True, tok=alice_access_token
- )
- self.assertIsNotNone(new_room_id)
-
- # Have Bob join the room
- self.helper.invite(
- new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
- )
- self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
-
- # Now have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=30000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync contains the updated device list.
- # If not, the client would only receive the device list update on the
- # *next* sync.
- bob_sync_channel.await_result()
- self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
-
- def test_not_receiving_local_device_list_changes(self) -> None:
- """Tests a local users DO NOT receive device updates from each other if they do not
- share a room.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # These users do not share a room. They are lonely.
-
- # Have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=1000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync does not contain the updated device list.
- bob_sync_channel.await_result()
- self.assertEqual(
- bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
- )
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertNotIn(
- alice_user_id, changed_device_lists, bob_sync_channel.json_body
- )
-
-
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class SendToDeviceTestCase(HomeserverTestCase):
+ """
+ Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
- # it should re-appear if we do another sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # it should re-appear if we do another sync because the to-device message is not
+ # deleted until we acknowledge it by sending a `?since=...` parameter in the
+ # next sync request corresponding to the `next_batch` value from the response.
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
- # now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # now sync: we should get two of the three (because burst_count=2)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
- {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+ {
+ "sender": user1,
+ "type": "m.room_key_request",
+ "content": {"idx": i},
+ },
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
import json
from typing import List
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device list (`device_lists`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def test_receiving_local_device_list_changes(self) -> None:
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self) -> None:
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
+
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
- device_id = "TESTDEVICE"
+ test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey", device_id=device_id)
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
# Request an initial sync
- channel = self.make_request("GET", "/sync", access_token=self.tok)
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
- f"/sync?since={next_batch}&timeout=30000",
- access_token=self.tok,
+ f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+ access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
- f"devices/{device_id}",
+ f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
- access_token=self.tok,
+ access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
- self.user_id, device_list_changes, incremental_sync_channel.json_body
+ alice_user_id, device_list_changes, incremental_sync_channel.json_body
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_one_time_keys(self) -> None:
+ """
+ Tests when no one time keys set, it still has the default `signed_curve25519` in
+ `device_one_time_keys_count`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ {"signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that one time keys for the device/user are counted correctly in the `/sync`
+ response
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Upload one time keys for the user/device
+ keys: JsonDict = {
+ "alg1:k1": "key1",
+ "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+ "alg2:k3": {"key": "key3"},
+ }
+ res = self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id, test_device_id, {"one_time_keys": keys}
+ )
+ )
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ self.assertDictEqual(
+ res,
+ {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = self.hs.get_datastores().main
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_unused_fallback_key(self) -> None:
+ """
+ Test when no unused fallback key is set, it just returns an empty list. The MSC
+ says "The device_unused_fallback_key_types parameter must be present if the
+ server supports fallback keys.",
+ https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ [],
+ channel.json_body["device_unused_fallback_key_types"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that device unused fallback key type is returned correctly in the `/sync`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # We shouldn't have any unused fallback keys yet
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(res, [])
+
+ # Upload a fallback key for the user/device
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id,
+ test_device_id,
+ {"fallback_keys": fallback_key},
+ )
+ )
+ # We should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for the unused fallback key types
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ ["alg1"],
+ channel.json_body["device_unused_fallback_key_types"],
)
|