summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7158.misc1
-rw-r--r--synapse/storage/data_stores/main/devices.py10
-rw-r--r--tests/federation/test_federation_sender.py6
3 files changed, 15 insertions, 2 deletions
diff --git a/changelog.d/7158.misc b/changelog.d/7158.misc
new file mode 100644
index 0000000000..269b8daeb0
--- /dev/null
+++ b/changelog.d/7158.misc
@@ -0,0 +1 @@
+Fix device list update stream ids going backward.
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 20995e1b78..dd3561e9b2 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -165,7 +165,6 @@ class DeviceWorkerStore(SQLBaseStore):
         # the max stream_id across each set of duplicate entries
         #
         # maps (user_id, device_id) -> (stream_id, opentracing_context)
-        # as long as their stream_id does not match that of the last row
         #
         # opentracing_context contains the opentracing metadata for the request
         # that created the poke
@@ -270,7 +269,14 @@ class DeviceWorkerStore(SQLBaseStore):
             prev_id = yield self._get_last_device_update_for_remote_user(
                 destination, user_id, from_stream_id
             )
-            for device_id, device in iteritems(user_devices):
+
+            # make sure we go through the devices in stream order
+            device_ids = sorted(
+                user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+            )
+
+            for device_id in device_ids:
+                device = user_devices[device_id]
                 stream_id, opentracing_context = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index a5fe5c6880..33105576af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -297,6 +297,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             c = edu["content"]
             if stream_id is not None:
                 self.assertEqual(c["prev_id"], [stream_id])
+                self.assertGreaterEqual(c["stream_id"], stream_id)
             stream_id = c["stream_id"]
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2"}, devices)
@@ -330,6 +331,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
                 c.items(),
                 {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
             )
+            self.assertGreaterEqual(c["stream_id"], stream_id)
             stream_id = c["stream_id"]
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
@@ -366,6 +368,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.assertEqual(edu["edu_type"], "m.device_list_update")
             c = edu["content"]
             self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
+            if stream_id is not None:
+                self.assertGreaterEqual(c["stream_id"], stream_id)
             stream_id = c["stream_id"]
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
@@ -482,6 +486,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         }
 
         self.assertLessEqual(expected.items(), content.items())
+        if prev_stream_id is not None:
+            self.assertGreaterEqual(content["stream_id"], prev_stream_id)
         return content["stream_id"]
 
     def check_signing_key_update_txn(self, txn: JsonDict,) -> None: