summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py62
1 files changed, 42 insertions, 20 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d102e07372..d2b113a4e7 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,6 +24,7 @@ from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import Cache, SQLBaseStore, db_to_json
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 logger = logging.getLogger(__name__)
@@ -149,9 +150,7 @@ class DeviceWorkerStore(SQLBaseStore):
             defer.returnValue((stream_id_cutoff, []))
 
         results = yield self._get_device_update_edus_by_remote(
-            destination,
-            from_stream_id,
-            query_map,
+            destination, from_stream_id, query_map
         )
 
         defer.returnValue((now_stream_id, results))
@@ -182,9 +181,7 @@ class DeviceWorkerStore(SQLBaseStore):
         return list(txn)
 
     @defer.inlineCallbacks
-    def _get_device_update_edus_by_remote(
-        self, destination, from_stream_id, query_map,
-    ):
+    def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
         """Returns a list of device update EDUs as well as E2EE keys
 
         Args:
@@ -210,7 +207,7 @@ class DeviceWorkerStore(SQLBaseStore):
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
             prev_id = yield self._get_last_device_update_for_remote_user(
-                destination, user_id, from_stream_id,
+                destination, user_id, from_stream_id
             )
             for device_id, device in iteritems(user_devices):
                 stream_id = query_map[(user_id, device_id)]
@@ -238,7 +235,7 @@ class DeviceWorkerStore(SQLBaseStore):
         defer.returnValue(results)
 
     def _get_last_device_update_for_remote_user(
-        self, destination, user_id, from_stream_id,
+        self, destination, user_id, from_stream_id
     ):
         def f(txn):
             prev_sent_id_sql = """
@@ -395,22 +392,47 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return now_stream_id, []
 
-    @defer.inlineCallbacks
-    def get_user_whose_devices_changed(self, from_key):
-        """Get set of users whose devices have changed since `from_key`.
+    def get_users_whose_devices_changed(self, from_key, user_ids):
+        """Get set of users whose devices have changed since `from_key` that
+        are in the given list of user_ids.
+
+        Args:
+            from_key (str): The device lists stream token
+            user_ids (Iterable[str])
+
+        Returns:
+            Deferred[set[str]]: The set of user_ids whose devices have changed
+            since `from_key`
         """
         from_key = int(from_key)
-        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
-        if changed is not None:
-            defer.returnValue(set(changed))
 
-        sql = """
-            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
-        """
-        rows = yield self._execute(
-            "get_user_whose_devices_changed", None, sql, from_key
+        # Get set of users who *may* have changed. Users not in the returned
+        # list have definitely not changed.
+        to_check = list(
+            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        )
+
+        if not to_check:
+            return defer.succeed(set())
+
+        def _get_users_whose_devices_changed_txn(txn):
+            changes = set()
+
+            sql = """
+                SELECT DISTINCT user_id FROM device_lists_stream
+                WHERE stream_id > ?
+                AND user_id IN (%s)
+            """
+
+            for chunk in batch_iter(to_check, 100):
+                txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
+                changes.update(user_id for user_id, in txn)
+
+            return changes
+
+        return self.runInteraction(
+            "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
-        defer.returnValue(set(row[0] for row in rows))
 
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the