summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/devices.py111
1 files changed, 71 insertions, 40 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8ba995df3b..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+    AllEntitiesChangedResult,
+    StreamChangeCache,
+)
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -799,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_cached_device_list_changes(
         self,
         from_key: int,
-    ) -> Optional[List[str]]:
+    ) -> AllEntitiesChangedResult:
         """Get set of users whose devices have changed since `from_key`, or None
         if that information is not in our cache.
         """
@@ -807,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
     @cancellable
+    async def get_all_devices_changed(
+        self,
+        from_key: int,
+        to_key: int,
+    ) -> Set[str]:
+        """Get all users whose devices have changed in the given range.
+
+        Args:
+            from_key: The minimum device lists stream token to query device list
+                changes for, exclusive.
+            to_key: The maximum device lists stream token to query device list
+                changes for, inclusive.
+
+        Returns:
+            The set of user_ids whose devices have changed since `from_key`
+            (exclusive) until `to_key` (inclusive).
+        """
+
+        result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+        if result.hit:
+            # We know which users might have changed devices.
+            if not result.entities:
+                # If no users then we can return early.
+                return set()
+
+            # Otherwise we need to filter down the list
+            return await self.get_users_whose_devices_changed(
+                from_key, result.entities, to_key
+            )
+
+        # If the cache didn't tell us anything, we just need to query the full
+        # range.
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_stream
+            WHERE ? < stream_id AND stream_id <= ?
+        """
+
+        rows = await self.db_pool.execute(
+            "get_all_devices_changed",
+            None,
+            sql,
+            from_key,
+            to_key,
+        )
+        return {u for u, in rows}
+
+    @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
-        user_ids: Optional[Collection[str]] = None,
+        user_ids: Collection[str],
         to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
@@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        user_ids_to_check: Optional[Collection[str]]
-        if user_ids is None:
-            # Get set of all users that have had device list changes since 'from_key'
-            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
-                from_key
-            )
-        else:
-            # The same as above, but filter results to only those users in 'user_ids'
-            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
-                user_ids, from_key
-            )
+        user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
+        )
 
         # If an empty set was returned, there's nothing to do.
-        if user_ids_to_check is not None and not user_ids_to_check:
+        if not user_ids_to_check:
             return set()
 
-        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            stream_id_where_clause = "stream_id > ?"
-            sql_args = [from_key]
-
-            if to_key:
-                stream_id_where_clause += " AND stream_id <= ?"
-                sql_args.append(to_key)
+        if to_key is None:
+            to_key = self._device_list_id_gen.get_current_token()
 
-            sql = f"""
+        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
+            sql = """
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE {stream_id_where_clause}
+                WHERE  ? < stream_id AND stream_id <= ? AND %s
             """
 
-            # If the stream change cache gave us no information, fetch *all*
-            # users between the stream IDs.
-            if user_ids_to_check is None:
-                txn.execute(sql, sql_args)
-                return {user_id for user_id, in txn}
+            changes: Set[str] = set()
 
-            # Otherwise, fetch changes for the given users.
-            else:
-                changes: Set[str] = set()
-
-                # Query device changes with a batch of users at a time
-                for chunk in batch_iter(user_ids_to_check, 100):
-                    clause, args = make_in_list_sql_clause(
-                        txn.database_engine, "user_id", chunk
-                    )
-                    txn.execute(sql + " AND " + clause, sql_args + args)
-                    changes.update(user_id for user_id, in txn)
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "user_id", chunk
+                )
+                txn.execute(sql % (clause,), [from_key, to_key] + args)
+                changes.update(user_id for user_id, in txn)
 
             return changes