diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 3413a46675..3af0171f75 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -391,22 +391,53 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, []
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
+ def get_user_whose_devices_changed(self, from_key, user_ids):
+ """Get set of users whose devices have changed since `from_key` that
+ are in the given list of user_ids.
+
+ Args:
+ user_ids (Iterable[str])
+ from_key: The device lists stream token
+
+ Returns:
+ Deferred[set[str]]: The set of user_ids whose devices have changed
+ since `from_key`
"""
from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
+
+ # Get set of users who *may* have changed. Users not in the returned
+ # list have definitely not changed.
+ to_check = list(
+ self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ )
+
+ if not to_check:
+ return defer.succeed(set())
+
+ # We now check the database for all users in `to_check`, in batches.
+ batch_size = 100
+ chunks = [
+ to_check[i : i + batch_size] for i in range(0, len(to_check), batch_size)
+ ]
sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE stream_id > ?
+ AND user_id IN (%s)
"""
- rows = yield self._execute(
- "get_user_whose_devices_changed", None, sql, from_key
+
+ def _get_user_whose_devices_changed_txn(txn):
+ changes = set()
+
+ for chunk in chunks:
+ txn.execute(sql % (",".join("?" for _ in chunk),), [from_key] + chunk)
+ changes.update(user_id for user_id, in txn)
+
+ return changes
+
+ return self.runInteraction(
+ "get_user_whose_devices_changed", _get_user_whose_devices_changed_txn
)
- defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
|