summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11729.bugfix1
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--tests/storage/test_devices.py50
3 files changed, 57 insertions, 2 deletions
diff --git a/changelog.d/11729.bugfix b/changelog.d/11729.bugfix
new file mode 100644
index 0000000000..8438ce5686
--- /dev/null
+++ b/changelog.d/11729.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.0.0 whereby some device list updates would not be sent to remote homeservers if there were too many to send at once.
\ No newline at end of file
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 273adb61fd..324bd5f879 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -270,6 +270,10 @@ class DeviceWorkerStore(SQLBaseStore):
         # The most recent request's opentracing_context is used as the
         # context which created the Edu.
 
+        # This is the stream ID that we will return for the consumer to resume
+        # following this stream later.
+        last_processed_stream_id = from_stream_id
+
         query_map = {}
         cross_signing_keys_by_user = {}
         for user_id, device_id, update_stream_id, update_context in updates:
@@ -295,6 +299,8 @@ class DeviceWorkerStore(SQLBaseStore):
                 if update_stream_id > previous_update_stream_id:
                     query_map[key] = (update_stream_id, update_context)
 
+            last_processed_stream_id = update_stream_id
+
         results = await self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
@@ -307,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
             # FIXME: remove this when enough servers have upgraded
             results.append(("org.matrix.signing_key_update", result))
 
-        return now_stream_id, results
+        return last_processed_stream_id, results
 
     def _get_device_updates_by_remote_txn(
         self,
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6790aa5242..5cd7f6bb7a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -94,7 +94,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
-        # Add two device updates with a single stream_id
+        # Add two device updates with sequential `stream_id`s
         self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
@@ -107,6 +107,54 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Check original device_ids are contained within these updates
         self._check_devices_in_updates(device_ids, device_updates)
 
+    def test_get_device_updates_by_remote_can_limit_properly(self):
+        """
+        Tests that `get_device_updates_by_remote` returns an appropriate
+        stream_id to resume fetching from (without skipping any results).
+        """
+
+        # Add some device updates with sequential `stream_id`s
+        device_ids = [
+            "device_id1",
+            "device_id2",
+            "device_id3",
+            "device_id4",
+            "device_id5",
+        ]
+        self.get_success(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+        )
+
+        # Get all device updates ever meant for this remote
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=3)
+        )
+
+        # Check the first three original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids[:3], device_updates)
+
+        # Get the next batch of device updates
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+
+        # Check the last two original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids[3:], device_updates)
+
+        # Add some more device updates to ensure it still resumes properly
+        device_ids = ["device_id6", "device_id7"]
+        self.get_success(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+        )
+
+        # Get the next batch of device updates
+        next_stream_id, device_updates = self.get_success(
+            self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
+        )
+
+        # Check the newly-added device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids, device_updates)
+
     def _check_devices_in_updates(self, expected_device_ids, device_updates):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))