summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5156.bugfix1
-rw-r--r--synapse/federation/sender/per_destination_queue.py5
-rw-r--r--synapse/storage/devices.py152
-rw-r--r--tests/storage/test_devices.py69
4 files changed, 196 insertions, 31 deletions
diff --git a/changelog.d/5156.bugfix b/changelog.d/5156.bugfix
new file mode 100644
index 0000000000..e8aa7d8241
--- /dev/null
+++ b/changelog.d/5156.bugfix
@@ -0,0 +1 @@
+Prevent federation device list updates breaking when processing multiple updates at once.
\ No newline at end of file
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index fae8bea392..564c57203d 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -349,9 +349,10 @@ class PerDestinationQueue(object):
     @defer.inlineCallbacks
     def _get_new_device_messages(self, limit):
         last_device_list = self._last_device_list_stream_id
-        # Will return at most 20 entries
+
+        # Retrieve list of new device updates to send to the destination
         now_stream_id, results = yield self._store.get_devices_by_remote(
-            self._destination, last_device_list
+            self._destination, last_device_list, limit=limit,
         )
         edus = [
             Edu(
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index fd869b934c..d102e07372 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 
-from six import iteritems, itervalues
+from six import iteritems
 
 from canonicaljson import json
 
@@ -72,11 +72,14 @@ class DeviceWorkerStore(SQLBaseStore):
 
         defer.returnValue({d["device_id"]: d for d in devices})
 
-    def get_devices_by_remote(self, destination, from_stream_id):
+    @defer.inlineCallbacks
+    def get_devices_by_remote(self, destination, from_stream_id, limit):
         """Get stream of updates to send to remote servers
 
         Returns:
-            (int, list[dict]): current stream id and list of updates
+            Deferred[tuple[int, list[dict]]]:
+                current stream id (ie, the stream id of the last update included in the
+                response), and the list of updates
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -84,55 +87,131 @@ class DeviceWorkerStore(SQLBaseStore):
             destination, int(from_stream_id)
         )
         if not has_changed:
-            return (now_stream_id, [])
-
-        return self.runInteraction(
+            defer.returnValue((now_stream_id, []))
+
+        # We retrieve n+1 devices from the list of outbound pokes where n is
+        # our outbound device update limit. We then check if the very last
+        # device has the same stream_id as the second-to-last device. If so,
+        # then we ignore all devices with that stream_id and only send the
+        # devices with a lower stream_id.
+        #
+        # If when culling the list we end up with no devices afterwards, we
+        # consider the device update to be too large, and simply skip the
+        # stream_id; the rationale being that such a large device list update
+        # is likely an error.
+        updates = yield self.runInteraction(
             "get_devices_by_remote",
             self._get_devices_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
+            limit + 1,
         )
 
+        # Return an empty list if there are no updates
+        if not updates:
+            defer.returnValue((now_stream_id, []))
+
+        # if we have exceeded the limit, we need to exclude any results with the
+        # same stream_id as the last row.
+        if len(updates) > limit:
+            stream_id_cutoff = updates[-1][2]
+            now_stream_id = stream_id_cutoff - 1
+        else:
+            stream_id_cutoff = None
+
+        # Perform the equivalent of a GROUP BY
+        #
+        # Iterate through the updates list and copy non-duplicate
+        # (user_id, device_id) entries into a map, with the value being
+        # the max stream_id across each set of duplicate entries
+        #
+        # maps (user_id, device_id) -> stream_id
+        # as long as their stream_id does not match that of the last row
+        query_map = {}
+        for update in updates:
+            if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+                # Stop processing updates
+                break
+
+            key = (update[0], update[1])
+            query_map[key] = max(query_map.get(key, 0), update[2])
+
+        # If we didn't find any updates with a stream_id lower than the cutoff, it
+        # means that there are more than limit updates all of which have the same
+        # steam_id.
+
+        # That should only happen if a client is spamming the server with new
+        # devices, in which case E2E isn't going to work well anyway. We'll just
+        # skip that stream_id and return an empty list, and continue with the next
+        # stream_id next time.
+        if not query_map:
+            defer.returnValue((stream_id_cutoff, []))
+
+        results = yield self._get_device_update_edus_by_remote(
+            destination,
+            from_stream_id,
+            query_map,
+        )
+
+        defer.returnValue((now_stream_id, results))
+
     def _get_devices_by_remote_txn(
-        self, txn, destination, from_stream_id, now_stream_id
+        self, txn, destination, from_stream_id, now_stream_id, limit
     ):
+        """Return device update information for a given remote destination
+
+        Args:
+            txn (LoggingTransaction): The transaction to execute
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            now_stream_id (int): The maximum stream_id to filter updates by, inclusive
+            limit (int): Maximum number of device updates to return
+
+        Returns:
+            List: List of device updates
+        """
         sql = """
-            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
-            GROUP BY user_id, device_id
-            LIMIT 20
+            ORDER BY stream_id
+            LIMIT ?
         """
-        txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
+        txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
 
-        # maps (user_id, device_id) -> stream_id
-        query_map = {(r[0], r[1]): r[2] for r in txn}
-        if not query_map:
-            return (now_stream_id, [])
+        return list(txn)
 
-        if len(query_map) >= 20:
-            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+    @defer.inlineCallbacks
+    def _get_device_update_edus_by_remote(
+        self, destination, from_stream_id, query_map,
+    ):
+        """Returns a list of device update EDUs as well as E2EE keys
 
-        devices = self._get_e2e_device_keys_txn(
-            txn,
+        Args:
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            query_map (Dict[(str, str): int]): Dictionary mapping
+                user_id/device_id to update stream_id
+
+        Returns:
+            List[Dict]: List of objects representing an device update EDU
+
+        """
+        devices = yield self.runInteraction(
+            "_get_e2e_device_keys_txn",
+            self._get_e2e_device_keys_txn,
             query_map.keys(),
             include_all_devices=True,
             include_deleted_devices=True,
         )
 
-        prev_sent_id_sql = """
-            SELECT coalesce(max(stream_id), 0) as stream_id
-            FROM device_lists_outbound_last_success
-            WHERE destination = ? AND user_id = ? AND stream_id <= ?
-        """
-
         results = []
         for user_id, user_devices in iteritems(devices):
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
-            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
-            rows = txn.fetchall()
-            prev_id = rows[0][0]
+            prev_id = yield self._get_last_device_update_for_remote_user(
+                destination, user_id, from_stream_id,
+            )
             for device_id, device in iteritems(user_devices):
                 stream_id = query_map[(user_id, device_id)]
                 result = {
@@ -156,7 +235,22 @@ class DeviceWorkerStore(SQLBaseStore):
 
                 results.append(result)
 
-        return (now_stream_id, results)
+        defer.returnValue(results)
+
+    def _get_last_device_update_for_remote_user(
+        self, destination, user_id, from_stream_id,
+    ):
+        def f(txn):
+            prev_sent_id_sql = """
+                SELECT coalesce(max(stream_id), 0) as stream_id
+                FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ? AND stream_id <= ?
+            """
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+
+        return self.runInteraction("get_last_device_update_for_remote_user", f)
 
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         """Mark that updates have successfully been sent to the destination.
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index aef4dfaf57..6396ccddb5 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,6 +72,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_get_devices_by_remote(self):
+        device_ids = ["device_id1", "device_id2"]
+
+        # Add two device updates with a single stream_id
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids, ["somehost"],
+        )
+
+        # Get all device updates ever meant for this remote
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "somehost", -1, limit=100,
+        )
+
+        # Check original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids, device_updates)
+
+    @defer.inlineCallbacks
+    def test_get_devices_by_remote_limited(self):
+        # Test breaking the update limit in 1, 101, and 1 device_id segments
+
+        # first add one device
+        device_ids1 = ["device_id0"]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids1, ["someotherhost"],
+        )
+
+        # then add 101
+        device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids2, ["someotherhost"],
+        )
+
+        # then one more
+        device_ids3 = ["newdevice"]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids3, ["someotherhost"],
+        )
+
+        #
+        # now read them back.
+        #
+
+        # first we should get a single update
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", -1, limit=100,
+        )
+        self._check_devices_in_updates(device_ids1, device_updates)
+
+        # Then we should get an empty list back as the 101 devices broke the limit
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", now_stream_id, limit=100,
+        )
+        self.assertEqual(len(device_updates), 0)
+
+        # The 101 devices should've been cleared, so we should now just get one device
+        # update
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", now_stream_id, limit=100,
+        )
+        self._check_devices_in_updates(device_ids3, device_updates)
+
+    def _check_devices_in_updates(self, expected_device_ids, device_updates):
+        """Check that an specific device ids exist in a list of device update EDUs"""
+        self.assertEqual(len(device_updates), len(expected_device_ids))
+
+        received_device_ids = {update["device_id"] for update in device_updates}
+        self.assertEqual(received_device_ids, set(expected_device_ids))
+
+    @defer.inlineCallbacks
     def test_update_device(self):
         yield self.store.store_device("user_id", "device_id", "display_name 1")