summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/devices.py51
1 files changed, 41 insertions, 10 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 3413a46675..3af0171f75 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -391,22 +391,53 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return now_stream_id, []
 
-    @defer.inlineCallbacks
-    def get_user_whose_devices_changed(self, from_key):
-        """Get set of users whose devices have changed since `from_key`.
+    def get_user_whose_devices_changed(self, from_key, user_ids):
+        """Get set of users whose devices have changed since `from_key` that
+        are in the given list of user_ids.
+
+        Args:
+            user_ids (Iterable[str])
+            from_key: The device lists stream token
+
+        Returns:
+            Deferred[set[str]]: The set of user_ids whose devices have changed
+            since `from_key`
         """
         from_key = int(from_key)
-        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
-        if changed is not None:
-            defer.returnValue(set(changed))
+
+        # Get set of users who *may* have changed. Users not in the returned
+        # list have definitely not changed.
+        to_check = list(
+            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        )
+
+        if not to_check:
+            return defer.succeed(set())
+
+        # We now check the database for all users in `to_check`, in batches.
+        batch_size = 100
+        chunks = [
+            to_check[i : i + batch_size] for i in range(0, len(to_check), batch_size)
+        ]
 
         sql = """
-            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+            SELECT DISTINCT user_id FROM device_lists_stream
+            WHERE stream_id > ?
+            AND user_id IN (%s)
         """
-        rows = yield self._execute(
-            "get_user_whose_devices_changed", None, sql, from_key
+
+        def _get_user_whose_devices_changed_txn(txn):
+            changes = set()
+
+            for chunk in chunks:
+                txn.execute(sql % (",".join("?" for _ in chunk),), [from_key] + chunk)
+                changes.update(user_id for user_id, in txn)
+
+            return changes
+
+        return self.runInteraction(
+            "get_user_whose_devices_changed", _get_user_whose_devices_changed_txn
         )
-        defer.returnValue(set(row[0] for row in rows))
 
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the