diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b76..9efca232cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
- self, from_key: int, user_ids: Iterable[str]
+ self,
+ from_key: int,
+ user_ids: Optional[Iterable[str]] = None,
+ to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
-
Args:
- from_key: The device lists stream token
- user_ids: The user IDs to query for devices.
-
+ from_key: The minimum device lists stream token to query device list changes for,
+ exclusive.
+ user_ids: If provided, only check if these users have changed their device lists.
+ Otherwise changes from all users are returned.
+ to_key: The maximum device lists stream token to query device list changes for,
+ inclusive.
Returns:
- The set of user_ids whose devices have changed since `from_key`
+ The set of user_ids whose devices have changed since `from_key` (exclusive)
+ until `to_key` (inclusive).
"""
-
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ if user_ids is None:
+ # Get set of all users that have had device list changes since 'from_key'
+ user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ else:
+ # The same as above, but filter results to only those users in 'user_ids'
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
- if not to_check:
+ if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
- sql = """
+ stream_id_where_clause = "stream_id > ?"
+ sql_args = [from_key]
+
+ if to_key:
+ stream_id_where_clause += " AND stream_id <= ?"
+ sql_args += [to_key]
+
+ sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE stream_id > ?
+ WHERE {stream_id_where_clause}
AND
"""
- for chunk in batch_iter(to_check, 100):
+ # Query device changes with a batch of users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, (from_key,) + tuple(args))
+ sql_args += args
+
+ txn.execute(sql + clause, sql_args)
changes.update(user_id for user_id, in txn)
return changes
|