summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/stream_change_cache.py52
1 files changed, 37 insertions, 15 deletions
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 042de8d7c8..c8b17acb59 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
 import math
 from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
 
+import attr
 from sortedcontainers import SortedDict
 
 from synapse.util import caches
@@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
 EntityType = str
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+    """Return type of `get_all_entities_changed`.
+
+    Callers must check that there was a cache hit, via `result.hit`, before
+    using the entities in `result.entities`.
+
+    This specifically does *not* implement helpers such as `__bool__` to ensure
+    that callers do the correct checks.
+    """
+
+    _entities: Optional[List[EntityType]]
+
+    @property
+    def hit(self) -> bool:
+        return self._entities is not None
+
+    @property
+    def entities(self) -> List[EntityType]:
+        assert self._entities is not None
+        return self._entities
+
+
 class StreamChangeCache:
     """
     Keeps track of the stream positions of the latest change in a set of entities.
@@ -153,19 +177,19 @@ class StreamChangeCache:
             This will be all entities if the given stream position is at or earlier
             than the earliest known stream position.
         """
-        changed_entities = self.get_all_entities_changed(stream_pos)
-        if changed_entities is not None:
+        cache_result = self.get_all_entities_changed(stream_pos)
+        if cache_result.hit:
             # We now do an intersection, trying to do so in the most efficient
             # way possible (some of these sets are *large*). First check in the
             # given iterable is already a set that we can reuse, otherwise we
             # create a set of the *smallest* of the two iterables and call
             # `intersection(..)` on it (this can be twice as fast as the reverse).
             if isinstance(entities, (set, frozenset)):
-                result = entities.intersection(changed_entities)
-            elif len(changed_entities) < len(entities):
-                result = set(changed_entities).intersection(entities)
+                result = entities.intersection(cache_result.entities)
+            elif len(cache_result.entities) < len(entities):
+                result = set(cache_result.entities).intersection(entities)
             else:
-                result = set(entities).intersection(changed_entities)
+                result = set(entities).intersection(cache_result.entities)
             self.metrics.inc_hits()
         else:
             result = set(entities)
@@ -202,12 +226,12 @@ class StreamChangeCache:
         self.metrics.inc_hits()
         return stream_pos < self._cache.peekitem()[0]
 
-    def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
+    def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
         """
         Returns all entities that have had changes after the given position.
 
-        If the stream change cache does not go far enough back, i.e. the position
-        is too old, it will return None.
+        If the stream change cache does not go far enough back, i.e. the
+        position is too old, it will return None.
 
         Returns the entities in the order that they were changed.
 
@@ -215,23 +239,21 @@ class StreamChangeCache:
             stream_pos: The stream position to check for changes after.
 
         Return:
-            Entities which have changed after the given stream position.
-
-            None if the given stream position is at or earlier than the earliest
-            known stream position.
+            A class indicating if we have the requested data cached, and if so
+            includes the entities in the order they were changed.
         """
         assert isinstance(stream_pos, int)
 
         # _cache is not valid at or before the earliest known stream position, so
         # return None to mark that it is unknown if an entity has changed.
         if stream_pos <= self._earliest_known_stream_pos:
-            return None
+            return AllEntitiesChangedResult(None)
 
         changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
-        return changed_entities
+        return AllEntitiesChangedResult(changed_entities)
 
     def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
         """