diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 5cd7f6bb7a..b547bf8d99 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -125,7 +125,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
- # Get all device updates ever meant for this remote
+ # Get device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
)
@@ -155,6 +155,116 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check the newly-added device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
+ # Check there are no more device updates left.
+ _, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+ self.assertEqual(device_updates, [])
+
+ def test_get_device_updates_by_remote_cross_signing_key_updates(
+ self,
+ ) -> None:
+ """
+ Tests that `get_device_updates_by_remote` limits the length of the return value
+ properly when cross-signing key updates are present.
+ Current behaviour is that the cross-signing key updates will always come in pairs,
+ even if that means leaving an earlier batch one EDU short of the limit.
+ """
+
+ assert self.hs.is_mine_id(
+ "@user_id:test"
+ ), "Test not valid: this MXID should be considered local"
+
+ self.get_success(
+ self.store.set_e2e_cross_signing_key(
+ "@user_id:test",
+ "master",
+ {
+ "keys": {
+ "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ },
+ "signatures": {
+ "@user_id:test": {
+ "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ }
+ },
+ },
+ )
+ )
+ self.get_success(
+ self.store.set_e2e_cross_signing_key(
+ "@user_id:test",
+ "self_signing",
+ {
+ "keys": {
+ "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ },
+ "signatures": {
+ "@user_id:test": {
+ "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+ }
+ },
+ },
+ )
+ )
+
+ # Add some device updates with sequential `stream_id`s
+ # Note that the public cross-signing keys occupy the same space as device IDs,
+ # so also notify that those have updated.
+ device_ids = [
+ "device_id1",
+ "device_id2",
+ "fakeMaster",
+ "fakeSelfSigning",
+ ]
+
+ self.get_success(
+ self.store.add_device_change_to_streams(
+ "@user_id:test", device_ids, ["somehost"]
+ )
+ )
+
+ # Get device updates meant for this remote
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+ )
+
+ # Here we expect the device updates for `device_id1` and `device_id2`.
+ # That means we only receive 2 updates this time around.
+ # If we had a higher limit, we would expect to see the pair of
+ # (unstable-prefixed & unprefixed) signing key updates for the device
+ # represented by `fakeMaster` and `fakeSelfSigning`.
+ # Our implementation only sends these two variants together, so we get
+ # a short batch.
+ self.assertEqual(len(device_updates), 2, device_updates)
+
+ # Check the first two devices (device_id1, device_id2) came out.
+ self._check_devices_in_updates(device_ids[:2], device_updates)
+
+ # Get more device updates meant for this remote
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # The next 2 updates should be a cross-signing key update
+ # (the master key update and the self-signing key update are combined into
+ # one 'signing key update', but the cross-signing key update is emitted
+ # twice, once with an unprefixed type and once again with an unstable-prefixed type)
+ # (This is a temporary arrangement for backwards compatibility!)
+ self.assertEqual(len(device_updates), 2, device_updates)
+ self.assertEqual(
+ device_updates[0][0], "m.signing_key_update", device_updates[0]
+ )
+ self.assertEqual(
+ device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+ )
+
+ # Check there are no more device updates left.
+ _, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+ self.assertEqual(device_updates, [])
+
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
|