summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py48
1 files changed, 35 insertions, 13 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b76..f08f7834d3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
     async def get_users_whose_devices_changed(
-        self, from_key: int, user_ids: Iterable[str]
+        self,
+        from_key: int,
+        user_ids: Optional[Iterable[str]] = None,
+        to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
 
         Args:
-            from_key: The device lists stream token
-            user_ids: The user IDs to query for devices.
+            from_key: The minimum device lists stream token to query device list changes for,
+                exclusive.
+            user_ids: If provided, only check if these users have changed their device lists.
+                Otherwise changes from all users are returned.
+            to_key: The maximum device lists stream token to query device list changes for,
+                inclusive.
 
         Returns:
-            The set of user_ids whose devices have changed since `from_key`
+            The set of user_ids whose devices have changed since `from_key` (exclusive)
+                until `to_key` (inclusive).
         """
-
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = self._device_list_stream_cache.get_entities_changed(
-            user_ids, from_key
-        )
+        if user_ids is None:
+            # Get set of all users that have had device list changes since 'from_key'
+            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+                from_key
+            )
+        else:
+            # The same as above, but filter results to only those users in 'user_ids'
+            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+                user_ids, from_key
+            )
 
-        if not to_check:
+        if not user_ids_to_check:
             return set()
 
         def _get_users_whose_devices_changed_txn(txn):
             changes = set()
 
-            sql = """
+            stream_id_where_clause = "stream_id > ?"
+            sql_args = [from_key]
+
+            if to_key:
+                stream_id_where_clause += " AND stream_id <= ?"
+                sql_args.append(to_key)
+
+            sql = f"""
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE stream_id > ?
+                WHERE {stream_id_where_clause}
                 AND
             """
 
-            for chunk in batch_iter(to_check, 100):
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "user_id", chunk
                 )
-                txn.execute(sql + clause, (from_key,) + tuple(args))
+                txn.execute(sql + clause, sql_args + args)
                 changes.update(user_id for user_id, in txn)
 
             return changes