diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 042de8d7c8..c8b17acb59 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
+import attr
from sortedcontainers import SortedDict
from synapse.util import caches
@@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
EntityType = str
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+ """Return type of `get_all_entities_changed`.
+
+ Callers must check that there was a cache hit, via `result.hit`, before
+ using the entities in `result.entities`.
+
+ This specifically does *not* implement helpers such as `__bool__` to ensure
+ that callers do the correct checks.
+ """
+
+ _entities: Optional[List[EntityType]]
+
+ @property
+ def hit(self) -> bool:
+ return self._entities is not None
+
+ @property
+ def entities(self) -> List[EntityType]:
+ assert self._entities is not None
+ return self._entities
+
+
class StreamChangeCache:
"""
Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
- changed_entities = self.get_all_entities_changed(stream_pos)
- if changed_entities is not None:
+ cache_result = self.get_all_entities_changed(stream_pos)
+ if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
- result = entities.intersection(changed_entities)
- elif len(changed_entities) < len(entities):
- result = set(changed_entities).intersection(entities)
+ result = entities.intersection(cache_result.entities)
+ elif len(cache_result.entities) < len(entities):
+ result = set(cache_result.entities).intersection(entities)
else:
- result = set(entities).intersection(changed_entities)
+ result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
@@ -202,12 +226,12 @@ class StreamChangeCache:
self.metrics.inc_hits()
return stream_pos < self._cache.peekitem()[0]
- def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
+ def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""
Returns all entities that have had changes after the given position.
- If the stream change cache does not go far enough back, i.e. the position
- is too old, it will return None.
+ If the stream change cache does not go far enough back, i.e. the
+ position is too old, it will return None.
Returns the entities in the order that they were changed.
@@ -215,23 +239,21 @@ class StreamChangeCache:
stream_pos: The stream position to check for changes after.
Return:
- Entities which have changed after the given stream position.
-
- None if the given stream position is at or earlier than the earliest
- known stream position.
+ A class indicating if we have the requested data cached, and if so
+ includes the entities in the order they were changed.
"""
assert isinstance(stream_pos, int)
# _cache is not valid at or before the earliest known stream position, so
# return None to mark that it is unknown if an entity has changed.
if stream_pos <= self._earliest_known_stream_pos:
- return None
+ return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
- return changed_entities
+ return AllEntitiesChangedResult(changed_entities)
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""
|