Diffstat (limited to 'synapse')
-rw-r--r--  synapse/handlers/appservice.py              |   4
-rw-r--r--  synapse/handlers/presence.py                |  12
-rw-r--r--  synapse/handlers/sync.py                    |   6
-rw-r--r--  synapse/handlers/typing.py                  |   8
-rw-r--r--  synapse/storage/databases/main/devices.py   | 111
-rw-r--r--  synapse/util/caches/stream_change_cache.py  |  52
6 files changed, 125 insertions, 68 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 66f5b8d108..f68027aaed 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -615,8 +615,8 @@ class ApplicationServicesHandler:
         )
 
         # Fetch the users who have modified their device list since then.
-        users_with_changed_device_lists = (
-            await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+        users_with_changed_device_lists = await self.store.get_all_devices_changed(
+            from_key, to_key=new_key
         )
 
         # Filter out any users the application service is not interested in
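
The replacement method, get_all_devices_changed, always returns a concrete set of
user IDs (falling back to a database scan on a cache miss), so the caller no longer
needs a None check before filtering. A minimal sketch of that calling pattern, in
which `store`, `appservice_users` (a set of user IDs the service cares about) and
the helper name are stand-ins rather than code from this change:

    async def interesting_device_changes(store, appservice_users, from_key, to_key):
        # Every user whose device list changed in the range (from_key, to_key].
        changed = await store.get_all_devices_changed(from_key, to_key=to_key)
        # Simplified stand-in for the handler's own interest filter.
        return changed & appservice_users
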
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 1799174c2f..2af90b25a3 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 
             if from_key is not None:
                 # First get all users that have had a presence update
-                updated_users = stream_change_cache.get_all_entities_changed(from_key)
+                result = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                if updated_users is not None:
+                if result.hit:
+                    updated_users = result.entities
+
                     # If we have the full list of changes for presence we can
                     # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
@@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         updated_users = None
         if from_key:
             # Only return updates since the last sync
-            updated_users = self.store.presence_stream_cache.get_all_entities_changed(
-                from_key
-            )
+            result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+            if result.hit:
+                updated_users = result.entities
 
         if updated_users is not None:
             # Get the actual presence update for each change
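
Both presence call sites now follow the same caller contract for the new result
type: a cache miss is an explicit `hit == False` rather than a None return. A
minimal sketch of that contract, assuming `stream_change_cache` is any
StreamChangeCache instance:

    result = stream_change_cache.get_all_entities_changed(from_key)
    if result.hit:
        # The cache covers from_key: `entities` is the complete, ordered
        # (possibly empty) list of entities that changed after it.
        updated_users = result.entities
    else:
        # The cache does not reach back to from_key; the caller must treat
        # every user as potentially changed.
        updated_users = None
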
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8858b22dd..0b395a104d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1528,10 +1528,12 @@ class SyncHandler:
             #
             # If we don't have that info cached then we get all the users that
             # share a room with our user and check if those users have changed.
-            changed_users = self.store.get_cached_device_list_changes(
+            cache_result = self.store.get_cached_device_list_changes(
                 since_token.device_list_key
             )
-            if changed_users is not None:
+            if cache_result.hit:
+                changed_users = cache_result.entities
+
                 result = await self.store.get_rooms_for_users(changed_users)
 
                 for changed_user_id, entries in result.items():
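
On a cache miss the sync handler takes the fallback described in the comment
above: fetch everyone who shares a room with the syncing user and ask the
database which of them changed. A hedged sketch of that branch (the store method
name and token attributes are assumptions based on the surrounding handler code,
not part of this diff):

    if not cache_result.hit:
        users_who_share_room = await store.get_users_who_share_room_with_user(user_id)
        changed_users = await store.get_users_whose_devices_changed(
            since_token.device_list_key,
            users_who_share_room,
            to_key=now_token.device_list_key,
        )
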
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a0ea719430..3f656ea4f5 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if last_id == current_id:
             return [], current_id, False
 
-        changed_rooms: Optional[
-            Iterable[str]
-        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+        result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 
-        if changed_rooms is None:
+        if result.hit:
+            changed_rooms: Iterable[str] = result.entities
+        else:
             changed_rooms = self._room_serials
 
         rows = []
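
Here the cache-miss fallback is to consider every room the handler has recorded
typing activity for: `_room_serials` maps each room ID to the serial of its last
typing update, so iterating it yields all candidate rooms. A compressed sketch of
the filtering that follows, with `room_serials` standing in for that mapping:

    changed_rooms = result.entities if result.hit else room_serials.keys()
    # Keep only rooms whose latest typing serial falls inside (last_id, current_id].
    recent = [
        (room_serials[room_id], room_id)
        for room_id in changed_rooms
        if last_id < room_serials[room_id] <= current_id
    ]
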
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8ba995df3b..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+    AllEntitiesChangedResult,
+    StreamChangeCache,
+)
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -799,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_cached_device_list_changes(
         self,
         from_key: int,
-    ) -> Optional[List[str]]:
+    ) -> AllEntitiesChangedResult:
         """Get set of users whose devices have changed since `from_key`, or None
         if that information is not in our cache.
         """
@@ -807,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
     @cancellable
+    async def get_all_devices_changed(
+        self,
+        from_key: int,
+        to_key: int,
+    ) -> Set[str]:
+        """Get all users whose devices have changed in the given range.
+
+        Args:
+            from_key: The minimum device lists stream token to query device list
+                changes for, exclusive.
+            to_key: The maximum device lists stream token to query device list
+                changes for, inclusive.
+
+        Returns:
+            The set of user_ids whose devices have changed since `from_key`
+            (exclusive) until `to_key` (inclusive).
+        """
+
+        result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+        if result.hit:
+            # We know which users might have changed devices.
+            if not result.entities:
+                # If no users then we can return early.
+                return set()
+
+            # Otherwise we need to filter down the list
+            return await self.get_users_whose_devices_changed(
+                from_key, result.entities, to_key
+            )
+
+        # If the cache didn't tell us anything, we just need to query the full
+        # range.
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_stream
+            WHERE ? < stream_id AND stream_id <= ?
+        """
+
+        rows = await self.db_pool.execute(
+            "get_all_devices_changed",
+            None,
+            sql,
+            from_key,
+            to_key,
+        )
+        return {u for u, in rows}
+
+    @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
-        user_ids: Optional[Collection[str]] = None,
+        user_ids: Collection[str],
         to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
@@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        user_ids_to_check: Optional[Collection[str]]
-        if user_ids is None:
-            # Get set of all users that have had device list changes since 'from_key'
-            user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
-                from_key
-            )
-        else:
-            # The same as above, but filter results to only those users in 'user_ids'
-            user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
-                user_ids, from_key
-            )
+        user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
+        )
 
         # If an empty set was returned, there's nothing to do.
-        if user_ids_to_check is not None and not user_ids_to_check:
+        if not user_ids_to_check:
             return set()
 
-        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            stream_id_where_clause = "stream_id > ?"
-            sql_args = [from_key]
-
-            if to_key:
-                stream_id_where_clause += " AND stream_id <= ?"
-                sql_args.append(to_key)
+        if to_key is None:
+            to_key = self._device_list_id_gen.get_current_token()
 
-            sql = f"""
+        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
+            sql = """
                 SELECT DISTINCT user_id FROM device_lists_stream
-                WHERE {stream_id_where_clause}
+                WHERE  ? < stream_id AND stream_id <= ? AND %s
             """
 
-            # If the stream change cache gave us no information, fetch *all*
-            # users between the stream IDs.
-            if user_ids_to_check is None:
-                txn.execute(sql, sql_args)
-                return {user_id for user_id, in txn}
+            changes: Set[str] = set()
 
-            # Otherwise, fetch changes for the given users.
-            else:
-                changes: Set[str] = set()
-
-                # Query device changes with a batch of users at a time
-                for chunk in batch_iter(user_ids_to_check, 100):
-                    clause, args = make_in_list_sql_clause(
-                        txn.database_engine, "user_id", chunk
-                    )
-                    txn.execute(sql + " AND " + clause, sql_args + args)
-                    changes.update(user_id for user_id, in txn)
+            # Query device changes with a batch of users at a time
+            for chunk in batch_iter(user_ids_to_check, 100):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "user_id", chunk
+                )
+                txn.execute(sql % (clause,), [from_key, to_key] + args)
+                changes.update(user_id for user_id, in txn)
 
             return changes
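
With `user_ids` now mandatory, the two store methods divide cleanly:
get_users_whose_devices_changed filters a known candidate set, while
get_all_devices_changed answers the unfiltered question for a token range. A
short sketch of how a caller chooses between them (`store` is a DeviceWorkerStore;
`tracked_users` is an assumed candidate set, not code from this change):

    # "Which of *these* users changed devices in (from_key, to_key]?"
    changed_subset = await store.get_users_whose_devices_changed(
        from_key, tracked_users, to_key=to_key
    )

    # "Which users changed devices at all in (from_key, to_key]?"
    all_changed = await store.get_all_devices_changed(from_key, to_key=to_key)
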
 
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 042de8d7c8..c8b17acb59 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
 import math
 from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
 
+import attr
 from sortedcontainers import SortedDict
 
 from synapse.util import caches
@@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
 EntityType = str
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+    """Return type of `get_all_entities_changed`.
+
+    Callers must check that there was a cache hit, via `result.hit`, before
+    using the entities in `result.entities`.
+
+    This specifically does *not* implement helpers such as `__bool__` to ensure
+    that callers do the correct checks.
+    """
+
+    _entities: Optional[List[EntityType]]
+
+    @property
+    def hit(self) -> bool:
+        return self._entities is not None
+
+    @property
+    def entities(self) -> List[EntityType]:
+        assert self._entities is not None
+        return self._entities
+
+
 class StreamChangeCache:
     """
     Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ class StreamChangeCache:
             This will be all entities if the given stream position is at or earlier
             than the earliest known stream position.
         """
-        changed_entities = self.get_all_entities_changed(stream_pos)
-        if changed_entities is not None:
+        cache_result = self.get_all_entities_changed(stream_pos)
+        if cache_result.hit:
             # We now do an intersection, trying to do so in the most efficient
             # way possible (some of these sets are *large*). First check in the
             # given iterable is already a set that we can reuse, otherwise we
             # create a set of the *smallest* of the two iterables and call
             # `intersection(..)` on it (this can be twice as fast as the reverse).
             if isinstance(entities, (set, frozenset)):
-                result = entities.intersection(changed_entities)
-            elif len(changed_entities) < len(entities):
-                result = set(changed_entities).intersection(entities)
+                result = entities.intersection(cache_result.entities)
+            elif len(cache_result.entities) < len(entities):
+                result = set(cache_result.entities).intersection(entities)
             else:
-                result = set(entities).intersection(changed_entities)
+                result = set(entities).intersection(cache_result.entities)
             self.metrics.inc_hits()
         else:
             result = set(entities)
@@ -202,12 +226,12 @@ class StreamChangeCache:
         self.metrics.inc_hits()
         return stream_pos < self._cache.peekitem()[0]
 
-    def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
+    def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
         """
         Returns all entities that have had changes after the given position.
 
-        If the stream change cache does not go far enough back, i.e. the position
-        is too old, it will return None.
+        If the stream change cache does not go far enough back, i.e. the
+        position is too old, the returned result will report a cache miss.
 
         Returns the entities in the order that they were changed.
 
@@ -215,23 +239,21 @@ class StreamChangeCache:
             stream_pos: The stream position to check for changes after.
 
         Return:
-            Entities which have changed after the given stream position.
-
-            None if the given stream position is at or earlier than the earliest
-            known stream position.
+            An `AllEntitiesChangedResult` indicating whether the requested data
+            is cached and, if so, the entities in the order they were changed.
         """
         assert isinstance(stream_pos, int)
 
         # _cache is not valid at or before the earliest known stream position, so
         # return None to mark that it is unknown if an entity has changed.
         if stream_pos <= self._earliest_known_stream_pos:
-            return None
+            return AllEntitiesChangedResult(None)
 
         changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
-        return changed_entities
+        return AllEntitiesChangedResult(changed_entities)
 
     def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
         """