diff options
-rw-r--r-- | changelog.d/11729.bugfix | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 8 | ||||
-rw-r--r-- | tests/storage/test_devices.py | 50 |
3 files changed, 57 insertions, 2 deletions
diff --git a/changelog.d/11729.bugfix b/changelog.d/11729.bugfix new file mode 100644 index 0000000000..8438ce5686 --- /dev/null +++ b/changelog.d/11729.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.0.0 whereby some device list updates would not be sent to remote homeservers if there were too many to send at once. \ No newline at end of file diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 273adb61fd..324bd5f879 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -270,6 +270,10 @@ class DeviceWorkerStore(SQLBaseStore): # The most recent request's opentracing_context is used as the # context which created the Edu. + # This is the stream ID that we will return for the consumer to resume + # following this stream later. + last_processed_stream_id = from_stream_id + query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: @@ -295,6 +299,8 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) + last_processed_stream_id = update_stream_id + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -307,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore): # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) - return now_stream_id, results + return last_processed_stream_id, results def _get_device_updates_by_remote_txn( self, diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6790aa5242..5cd7f6bb7a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -94,7 +94,7 @@ class DeviceStoreTestCase(HomeserverTestCase): def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] - # Add two device updates with a single stream_id + # Add two device updates with sequential `stream_id`s self.get_success( self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) @@ -107,6 +107,54 @@ class DeviceStoreTestCase(HomeserverTestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) + def test_get_device_updates_by_remote_can_limit_properly(self): + """ + Tests that `get_device_updates_by_remote` returns an appropriate + stream_id to resume fetching from (without skipping any results). + """ + + # Add some device updates with sequential `stream_id`s + device_ids = [ + "device_id1", + "device_id2", + "device_id3", + "device_id4", + "device_id5", + ] + self.get_success( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + ) + + # Get all device updates ever meant for this remote + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", -1, limit=3) + ) + + # Check the first three original device_ids are contained within these updates + self._check_devices_in_updates(device_ids[:3], device_updates) + + # Get the next batch of device updates + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + + # Check the last two original device_ids are contained within these updates + self._check_devices_in_updates(device_ids[3:], device_updates) + + # Add some more device updates to ensure it still resumes properly + device_ids = ["device_id6", "device_id7"] + self.get_success( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + ) + + # Get the next batch of device updates + next_stream_id, device_updates = self.get_success( + self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3) + ) + + # Check the newly-added device_ids are contained within these updates + self._check_devices_in_updates(device_ids, device_updates) + def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) |