summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_devices.py160
1 files changed, 159 insertions, 1 deletions
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6790aa5242..b547bf8d99 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -94,7 +94,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
-        # Add two device updates with a single stream_id
+        # Add two device updates with sequential `stream_id`s
         self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
@@ -107,6 +107,164 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Check original device_ids are contained within these updates
         self._check_devices_in_updates(device_ids, device_updates)
 
+    def test_get_device_updates_by_remote_can_limit_properly(self):
+        """
+        Tests that `get_device_updates_by_remote` returns an appropriate
+        stream_id to resume fetching from (without skipping any results).
+        """
+
+        # Add some device updates with sequential `stream_id`s
+        device_ids = [
+            "device_id1",
+            "device_id2",
+            "device_id3",
+            "device_id4",
+            "device_id5",
+        ]
+        self.get_success(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+        )
+
+        # Get device updates meant for this remote
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+        )
+
+        # Check the first three original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids[:3], device_updates)
+
+        # Get the next batch of device updates
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+
+        # Check the last two original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids[3:], device_updates)
+
+        # Add some more device updates to ensure it still resumes properly
+        device_ids = ["device_id6", "device_id7"]
+        self.get_success(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+        )
+
+        # Get the next batch of device updates
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+
+        # Check the newly-added device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids, device_updates)
+
+        # Check there are no more device updates left.
+        _, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+        self.assertEqual(device_updates, [])
+
+    def test_get_device_updates_by_remote_cross_signing_key_updates(
+        self,
+    ) -> None:
+        """
+        Tests that `get_device_updates_by_remote` limits the length of the return value
+        properly when cross-signing key updates are present.
+        Current behaviour is that the cross-signing key updates will always come in pairs,
+        even if that means leaving an earlier batch one EDU short of the limit.
+        """
+
+        assert self.hs.is_mine_id(
+            "@user_id:test"
+        ), "Test not valid: this MXID should be considered local"
+
+        self.get_success(
+            self.store.set_e2e_cross_signing_key(
+                "@user_id:test",
+                "master",
+                {
+                    "keys": {
+                        "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+                    },
+                    "signatures": {
+                        "@user_id:test": {
+                            "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+                        }
+                    },
+                },
+            )
+        )
+        self.get_success(
+            self.store.set_e2e_cross_signing_key(
+                "@user_id:test",
+                "self_signing",
+                {
+                    "keys": {
+                        "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+                    },
+                    "signatures": {
+                        "@user_id:test": {
+                            "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
+                        }
+                    },
+                },
+            )
+        )
+
+        # Add some device updates with sequential `stream_id`s
+        # Note that the public cross-signing keys occupy the same space as device IDs,
+        # so also notify that those have updated.
+        device_ids = [
+            "device_id1",
+            "device_id2",
+            "fakeMaster",
+            "fakeSelfSigning",
+        ]
+
+        self.get_success(
+            self.store.add_device_change_to_streams(
+                "@user_id:test", device_ids, ["somehost"]
+            )
+        )
+
+        # Get device updates meant for this remote
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+        )
+
+        # Here we expect the device updates for `device_id1` and `device_id2`.
+        # That means we only receive 2 updates this time around.
+        # If we had a higher limit, we would expect to see the pair of
+        # (unstable-prefixed & unprefixed) signing key updates for the device
+        # represented by `fakeMaster` and `fakeSelfSigning`.
+        # Our implementation only sends these two variants together, so we get
+        # a short batch.
+        self.assertEqual(len(device_updates), 2, device_updates)
+
+        # Check the first two devices (device_id1, device_id2) came out.
+        self._check_devices_in_updates(device_ids[:2], device_updates)
+
+        # Get more device updates meant for this remote
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+
+        # The next 2 updates should be a cross-signing key update
+        # (the master key update and the self-signing key update are combined into
+        # one 'signing key update', but the cross-signing key update is emitted
+        # twice, once with an unprefixed type and once again with an unstable-prefixed type)
+        # (This is a temporary arrangement for backwards compatibility!)
+        self.assertEqual(len(device_updates), 2, device_updates)
+        self.assertEqual(
+            device_updates[0][0], "m.signing_key_update", device_updates[0]
+        )
+        self.assertEqual(
+            device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
+        )
+
+        # Check there are no more device updates left.
+        _, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+        self.assertEqual(device_updates, [])
+
     def _check_devices_in_updates(self, expected_device_ids, device_updates):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))