diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6790aa5242..5cd7f6bb7a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -94,7 +94,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
- # Add two device updates with a single stream_id
+ # Add two device updates with sequential `stream_id`s
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
@@ -107,6 +107,54 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
+ def test_get_device_updates_by_remote_can_limit_properly(self):
+ """
+ Tests that `get_device_updates_by_remote` returns an appropriate
+ stream_id to resume fetching from (without skipping any results).
+ """
+
+ # Add some device updates with sequential `stream_id`s
+ device_ids = [
+ "device_id1",
+ "device_id2",
+ "device_id3",
+ "device_id4",
+ "device_id5",
+ ]
+ self.get_success(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ )
+
+ # Get all device updates ever meant for this remote
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+ )
+
+ # Check the first three original device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids[:3], device_updates)
+
+ # Get the next batch of device updates
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # Check the last two original device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids[3:], device_updates)
+
+ # Add some more device updates to ensure it still resumes properly
+ device_ids = ["device_id6", "device_id7"]
+ self.get_success(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ )
+
+ # Get the next batch of device updates
+ next_stream_id, device_updates = self.get_success(
+ self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+ )
+
+ # Check the newly-added device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids, device_updates)
+
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
|