summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/devices.py27
1 files changed, 12 insertions, 15 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 97f6cd2754..44324bf400 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,6 +24,7 @@ from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import Cache, SQLBaseStore, db_to_json
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 logger = logging.getLogger(__name__)
@@ -396,8 +397,8 @@ class DeviceWorkerStore(SQLBaseStore):
         are in the given list of user_ids.
 
         Args:
+            from_key (str): The device lists stream token
             user_ids (Iterable[str])
-            from_key: The device lists stream token
 
         Returns:
             Deferred[set[str]]: The set of user_ids whose devices have changed
@@ -414,23 +415,19 @@ class DeviceWorkerStore(SQLBaseStore):
         if not to_check:
             return defer.succeed(set())
 
-        # We now check the database for all users in `to_check`, in batches.
-        batch_size = 100
-        chunks = [
-            to_check[i : i + batch_size] for i in range(0, len(to_check), batch_size)
-        ]
-
-        sql = """
-            SELECT DISTINCT user_id FROM device_lists_stream
-            WHERE stream_id > ?
-            AND user_id IN (%s)
-        """
-
         def _get_users_whose_devices_changed_txn(txn):
             changes = set()
 
-            for chunk in chunks:
-                txn.execute(sql % (",".join("?" for _ in chunk),), [from_key] + chunk)
+            sql = """
+                SELECT DISTINCT user_id FROM device_lists_stream
+                WHERE stream_id > ?
+                AND user_id IN (%s)
+            """
+
+            for chunk in batch_iter(to_check, 100):
+                txn.execute(
+                    sql % (",".join("?" for _ in chunk),), [from_key] + list(chunk)
+                )
                 changes.update(user_id for user_id, in txn)
 
             return changes