diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 66f5b8d108..f68027aaed 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -615,8 +615,8 @@ class ApplicationServicesHandler:
)
# Fetch the users who have modified their device list since then.
- users_with_changed_device_lists = (
- await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+ users_with_changed_device_lists = await self.store.get_all_devices_changed(
+ from_key, to_key=new_key
)
# Filter out any users the application service is not interested in
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 1799174c2f..2af90b25a3 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
if from_key is not None:
# First get all users that have had a presence update
- updated_users = stream_change_cache.get_all_entities_changed(from_key)
+ result = stream_change_cache.get_all_entities_changed(from_key)
# Cross-reference users we're interested in with those that have had updates.
- if updated_users is not None:
+ if result.hit:
+ updated_users = result.entities
+
# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
@@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
updated_users = None
if from_key:
# Only return updates since the last sync
- updated_users = self.store.presence_stream_cache.get_all_entities_changed(
- from_key
- )
+ result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+ if result.hit:
+ updated_users = result.entities
if updated_users is not None:
# Get the actual presence update for each change
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8858b22dd..0b395a104d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1528,10 +1528,12 @@ class SyncHandler:
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
- changed_users = self.store.get_cached_device_list_changes(
+ cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
- if changed_users is not None:
+ if cache_result.hit:
+ changed_users = cache_result.entities
+
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index a0ea719430..3f656ea4f5 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
- changed_rooms: Optional[
- Iterable[str]
- ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+ result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
- if changed_rooms is None:
+ if result.hit:
+ changed_rooms: Iterable[str] = result.entities
+ else:
changed_rooms = self._room_serials
rows = []
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8ba995df3b..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+ AllEntitiesChangedResult,
+ StreamChangeCache,
+)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -799,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
- ) -> Optional[List[str]]:
+ ) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -807,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
+ async def get_all_devices_changed(
+ self,
+ from_key: int,
+ to_key: int,
+ ) -> Set[str]:
+ """Get all users whose devices have changed in the given range.
+
+ Args:
+ from_key: The minimum device lists stream token to query device list
+ changes for, exclusive.
+ to_key: The maximum device lists stream token to query device list
+ changes for, inclusive.
+
+ Returns:
+ The set of user_ids whose devices have changed since `from_key`
+ (exclusive) until `to_key` (inclusive).
+ """
+
+ result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+ if result.hit:
+ # We know which users might have changed devices.
+ if not result.entities:
+ # If no users then we can return early.
+ return set()
+
+ # Otherwise we need to filter down the list
+ return await self.get_users_whose_devices_changed(
+ from_key, result.entities, to_key
+ )
+
+ # If the cache didn't tell us anything, we just need to query the full
+ # range.
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+
+ rows = await self.db_pool.execute(
+ "get_all_devices_changed",
+ None,
+ sql,
+ from_key,
+ to_key,
+ )
+ return {u for u, in rows}
+
+ @cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
- user_ids: Optional[Collection[str]] = None,
+ user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- user_ids_to_check: Optional[Collection[str]]
- if user_ids is None:
- # Get set of all users that have had device list changes since 'from_key'
- user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
- from_key
- )
- else:
- # The same as above, but filter results to only those users in 'user_ids'
- user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
# If an empty set was returned, there's nothing to do.
- if user_ids_to_check is not None and not user_ids_to_check:
+ if not user_ids_to_check:
return set()
- def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
- stream_id_where_clause = "stream_id > ?"
- sql_args = [from_key]
-
- if to_key:
- stream_id_where_clause += " AND stream_id <= ?"
- sql_args.append(to_key)
+ if to_key is None:
+ to_key = self._device_list_id_gen.get_current_token()
- sql = f"""
+ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE {stream_id_where_clause}
+ WHERE ? < stream_id AND stream_id <= ? AND %s
"""
- # If the stream change cache gave us no information, fetch *all*
- # users between the stream IDs.
- if user_ids_to_check is None:
- txn.execute(sql, sql_args)
- return {user_id for user_id, in txn}
+ changes: Set[str] = set()
- # Otherwise, fetch changes for the given users.
- else:
- changes: Set[str] = set()
-
- # Query device changes with a batch of users at a time
- for chunk in batch_iter(user_ids_to_check, 100):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "user_id", chunk
- )
- txn.execute(sql + " AND " + clause, sql_args + args)
- changes.update(user_id for user_id, in txn)
+ # Query device changes with a batch of users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql % (clause,), [from_key, to_key] + args)
+ changes.update(user_id for user_id, in txn)
return changes
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 042de8d7c8..c8b17acb59 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
+import attr
from sortedcontainers import SortedDict
from synapse.util import caches
@@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
EntityType = str
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+ """Return type of `get_all_entities_changed`.
+
+ Callers must check that there was a cache hit, via `result.hit`, before
+ using the entities in `result.entities`.
+
+ This specifically does *not* implement helpers such as `__bool__` to ensure
+ that callers do the correct checks.
+ """
+
+ _entities: Optional[List[EntityType]]
+
+ @property
+ def hit(self) -> bool:
+ return self._entities is not None
+
+ @property
+ def entities(self) -> List[EntityType]:
+ assert self._entities is not None
+ return self._entities
+
+
class StreamChangeCache:
"""
Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
- changed_entities = self.get_all_entities_changed(stream_pos)
- if changed_entities is not None:
+ cache_result = self.get_all_entities_changed(stream_pos)
+ if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
- result = entities.intersection(changed_entities)
- elif len(changed_entities) < len(entities):
- result = set(changed_entities).intersection(entities)
+ result = entities.intersection(cache_result.entities)
+ elif len(cache_result.entities) < len(entities):
+ result = set(cache_result.entities).intersection(entities)
else:
- result = set(entities).intersection(changed_entities)
+ result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
@@ -202,12 +226,12 @@ class StreamChangeCache:
self.metrics.inc_hits()
return stream_pos < self._cache.peekitem()[0]
- def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
+ def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""
Returns all entities that have had changes after the given position.
- If the stream change cache does not go far enough back, i.e. the position
- is too old, it will return None.
+ If the stream change cache does not go far enough back, i.e. the
+ position is too old, it will return None.
Returns the entities in the order that they were changed.
@@ -215,23 +239,21 @@ class StreamChangeCache:
stream_pos: The stream position to check for changes after.
Return:
- Entities which have changed after the given stream position.
-
- None if the given stream position is at or earlier than the earliest
- known stream position.
+ A class indicating if we have the requested data cached, and if so
+ includes the entities in the order they were changed.
"""
assert isinstance(stream_pos, int)
# _cache is not valid at or before the earliest known stream position, so
# return None to mark that it is unknown if an entity has changed.
if stream_pos <= self._earliest_known_stream_pos:
- return None
+ return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
- return changed_entities
+ return AllEntitiesChangedResult(changed_entities)
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""
|