diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index fd869b934c..d102e07372 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from six import iteritems, itervalues
+from six import iteritems
from canonicaljson import json
@@ -72,11 +72,14 @@ class DeviceWorkerStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
- def get_devices_by_remote(self, destination, from_stream_id):
+ @defer.inlineCallbacks
+ def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
Returns:
- (int, list[dict]): current stream id and list of updates
+ Deferred[tuple[int, list[dict]]]:
+ current stream id (ie, the stream id of the last update included in the
+ response), and the list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -84,55 +87,131 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id)
)
if not has_changed:
- return (now_stream_id, [])
-
- return self.runInteraction(
+ defer.returnValue((now_stream_id, []))
+
+ # We retrieve n+1 devices from the list of outbound pokes where n is
+ # our outbound device update limit. If more than n rows come back, we
+ # take the stream_id of the very last row as a cutoff: we ignore all
+ # devices with that stream_id and only send the devices with a lower
+ # stream_id.
+ #
+ # If when culling the list we end up with no devices afterwards, we
+ # consider the device update to be too large, and simply skip the
+ # stream_id; the rationale being that such a large device list update
+ # is likely an error.
+ updates = yield self.runInteraction(
"get_devices_by_remote",
self._get_devices_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
+ limit + 1,
)
+ # Return an empty list if there are no updates
+ if not updates:
+ defer.returnValue((now_stream_id, []))
+
+ # if we have exceeded the limit, we need to exclude any results with the
+ # same stream_id as the last row.
+ if len(updates) > limit:
+ stream_id_cutoff = updates[-1][2]
+ now_stream_id = stream_id_cutoff - 1
+ else:
+ stream_id_cutoff = None
+
+ # Perform the equivalent of a GROUP BY
+ #
+ # Iterate through the updates list and copy non-duplicate
+ # (user_id, device_id) entries into a map, with the value being
+ # the max stream_id across each set of duplicate entries
+ #
+ # maps (user_id, device_id) -> stream_id
+ # as long as their stream_id does not match that of the last row
+ query_map = {}
+ for update in updates:
+ if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ # Stop processing updates
+ break
+
+ key = (update[0], update[1])
+ query_map[key] = max(query_map.get(key, 0), update[2])
+
+ # If we didn't find any updates with a stream_id lower than the cutoff, it
+ # means that there are more than limit updates all of which have the same
+ # stream_id.
+
+ # That should only happen if a client is spamming the server with new
+ # devices, in which case E2E isn't going to work well anyway. We'll just
+ # skip that stream_id and return an empty list, and continue with the next
+ # stream_id next time.
+ if not query_map:
+ defer.returnValue((stream_id_cutoff, []))
+
+ results = yield self._get_device_update_edus_by_remote(
+ destination,
+ from_stream_id,
+ query_map,
+ )
+
+ defer.returnValue((now_stream_id, results))
+
def _get_devices_by_remote_txn(
- self, txn, destination, from_stream_id, now_stream_id
+ self, txn, destination, from_stream_id, now_stream_id, limit
):
+ """Return device update information for a given remote destination
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ now_stream_id (int): The maximum stream_id to filter updates by, inclusive
+ limit (int): Maximum number of device updates to return
+
+ Returns:
+ List: (user_id, device_id, stream_id) rows, ordered by stream_id
+ """
sql = """
- SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
- GROUP BY user_id, device_id
- LIMIT 20
+ ORDER BY stream_id
+ LIMIT ?
"""
- txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in txn}
- if not query_map:
- return (now_stream_id, [])
+ return list(txn)
- if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+ @defer.inlineCallbacks
+ def _get_device_update_edus_by_remote(
+ self, destination, from_stream_id, query_map,
+ ):
+ """Returns a list of device update EDUs as well as E2EE keys
- devices = self._get_e2e_device_keys_txn(
- txn,
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ query_map (Dict[(str, str): int]): Dictionary mapping
+ user_id/device_id to update stream_id
+
+ Returns:
+ List[Dict]: List of objects representing a device update EDU
+
+ """
+ devices = yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)
- prev_sent_id_sql = """
- SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ? AND stream_id <= ?
- """
-
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
- txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
- rows = txn.fetchall()
- prev_id = rows[0][0]
+ prev_id = yield self._get_last_device_update_for_remote_user(
+ destination, user_id, from_stream_id,
+ )
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
@@ -156,7 +235,22 @@ class DeviceWorkerStore(SQLBaseStore):
results.append(result)
- return (now_stream_id, results)
+ defer.returnValue(results)
+
+ def _get_last_device_update_for_remote_user(
+ self, destination, user_id, from_stream_id,
+ ):
+ def f(txn):
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
+ """
+ txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+
+ return self.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
|