author     Patrick Cloke <patrickc@matrix.org>  2023-09-29 15:39:48 -0400
committer  Patrick Cloke <patrickc@matrix.org>  2023-09-29 15:39:48 -0400
commit     fcc4dc7181ebd28a5f54bff9e46f54994a2b3b6e (patch)
tree       2930f9117367eb342de6207a89245f29b53abbdd /synapse/util/caches
parent     Newsfile (diff)
parent     Downgrade repl stream time out error to warning (#16401) (diff)
download   synapse-fcc4dc7181ebd28a5f54bff9e46f54994a2b3b6e.tar.xz
Merge remote-tracking branch 'origin/develop' into erikj/rust_lru_cache
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--  synapse/util/caches/__init__.py              6
-rw-r--r--  synapse/util/caches/deferred_cache.py        8
-rw-r--r--  synapse/util/caches/descriptors.py         136
-rw-r--r--  synapse/util/caches/dictionary_cache.py     19
-rw-r--r--  synapse/util/caches/expiringcache.py        26
-rw-r--r--  synapse/util/caches/lrucache.py             31
-rw-r--r--  synapse/util/caches/response_cache.py       22
-rw-r--r--  synapse/util/caches/stream_change_cache.py 200
-rw-r--r--  synapse/util/caches/ttlcache.py             10
9 files changed, 240 insertions(+), 218 deletions(-)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 35c0be08b0..6ffa56217e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -98,7 +98,6 @@ class EvictionReason(Enum):
 
 @attr.s(slots=True, auto_attribs=True)
 class CacheMetric:
-
     _cache: Sized
     _cache_type: str
     _cache_name: str
@@ -197,7 +196,7 @@ def register_cache(
         resize_callback: A function which can be called to resize the cache.
 
     Returns:
-        CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+        an object which provides inc_{hits,misses,evictions} methods
     """
     if resizable:
         if not resize_callback:
@@ -205,8 +204,9 @@ def register_cache(
         add_resizable_cache(cache_name, resize_callback)
 
     metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
+    metric_name = "cache_%s_%s" % (cache_type, cache_name)
     caches_by_name[cache_name] = cache
-    CACHE_METRIC_REGISTRY.register_hook(metric.collect)
+    CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect)
     return metric
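
For illustration, a minimal sketch of the named-hook pattern this hunk adopts; the registry class below is a stand-in, not Synapse's real CACHE_METRIC_REGISTRY. Keying hooks by a stable name such as "cache_lru_foo" means re-registering the same cache replaces its hook instead of accumulating duplicates:

    from typing import Callable, Dict

    class HookRegistry:
        """Stand-in registry: hooks are keyed by name, so re-registration
        for the same cache replaces the old hook rather than duplicating it."""

        def __init__(self) -> None:
            self._hooks: Dict[str, Callable[[], None]] = {}

        def register_hook(self, name: str, hook: Callable[[], None]) -> None:
            self._hooks[name] = hook

        def collect_all(self) -> None:
            for hook in self._hooks.values():
                hook()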
 
 
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 6425f851ea..029eedcc6f 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -153,7 +153,7 @@ class DeferredCache(Generic[KT, VT]):
         Args:
             key:
             callback: Gets called when the entry in the cache is invalidated
-            update_metrics (bool): whether to update the cache hit rate metrics
+            update_metrics: whether to update the cache hit rate metrics
 
         Returns:
             A Deferred which completes with the result. Note that this may later fail
@@ -395,8 +395,8 @@ class DeferredCache(Generic[KT, VT]):
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
-            for entry in iterate_tree_cache_entry(entry):
-                for cb in entry.get_invalidation_callbacks(key):
+            for iter_entry in iterate_tree_cache_entry(entry):
+                for cb in iter_entry.get_invalidation_callbacks(key):
                     cb()
 
     def invalidate_all(self) -> None:
@@ -470,7 +470,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]):
     def deferred(self, key: KT) -> "defer.Deferred[VT]":
         if not self._deferred:
             self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-        return self._deferred.observe().addCallback(lambda res: res.get(key))
+        return self._deferred.observe().addCallback(lambda res: res[key])
 
     def add_invalidation_callback(
         self, key: KT, callback: Optional[Callable[[], None]]
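
The final hunk above swaps `res.get(key)` for `res[key]`, so a key missing from the bulk result now surfaces as a KeyError on the Deferred's errback chain instead of silently resolving to None. A minimal sketch of that behaviour, assuming Twisted is installed:

    from twisted.internet import defer

    d: "defer.Deferred[dict]" = defer.Deferred()
    d.addCallback(lambda res: res["missing"])   # res.get would yield None here
    d.addErrback(lambda f: print("errback:", f.type.__name__))
    d.callback({"present": 1})  # prints "errback: KeyError"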
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 10aff4d04a..8514a75a1c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import enum
 import functools
 import inspect
 import logging
@@ -53,10 +52,10 @@ CacheKey = Union[Tuple, Any]
 F = TypeVar("F", bound=Callable[..., Any])
 
 
-class _CachedFunction(Generic[F]):
-    invalidate: Any = None
-    invalidate_all: Any = None
-    prefill: Any = None
+class CachedFunction(Generic[F]):
+    invalidate: Callable[[Tuple[Any, ...]], None]
+    invalidate_all: Callable[[], None]
+    prefill: Callable[[Tuple[Any, ...], Any], None]
     cache: Any = None
     num_args: Any = None
 
@@ -146,109 +145,6 @@ class _CacheDescriptorBase:
         )
 
 
-class _LruCachedFunction(Generic[F]):
-    cache: LruCache[CacheKey, Any]
-    __call__: F
-
-
-def lru_cache(
-    *, max_entries: int = 1000, cache_context: bool = False
-) -> Callable[[F], _LruCachedFunction[F]]:
-    """A method decorator that applies a memoizing cache around the function.
-
-    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
-    that the signature is slightly different.
-
-    The main differences with functools.lru_cache are:
-        (a) the size of the cache can be controlled via the cache_factor mechanism
-        (b) the wrapped function can request a "cache_context" which provides a
-            callback mechanism to indicate that the result is no longer valid
-        (c) prometheus metrics are exposed automatically.
-
-    The function should take zero or more arguments, which are used as the key for the
-    cache. Single-argument functions use that argument as the cache key; otherwise the
-    arguments are built into a tuple.
-
-    Cached functions can be "chained" (i.e. a cached function can call other cached
-    functions and get appropriately invalidated when they called caches are
-    invalidated) by adding a special "cache_context" argument to the function
-    and passing that as a kwarg to all caches called. For example:
-
-        @lru_cache(cache_context=True)
-        def foo(self, key, cache_context):
-            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
-            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
-            return r1 + r2
-
-    The wrapped function also has a 'cache' property which offers direct access to the
-    underlying LruCache.
-    """
-
-    def func(orig: F) -> _LruCachedFunction[F]:
-        desc = LruCacheDescriptor(
-            orig,
-            max_entries=max_entries,
-            cache_context=cache_context,
-        )
-        return cast(_LruCachedFunction[F], desc)
-
-    return func
-
-
-class LruCacheDescriptor(_CacheDescriptorBase):
-    """Helper for @lru_cache"""
-
-    class _Sentinel(enum.Enum):
-        sentinel = object()
-
-    def __init__(
-        self,
-        orig: Callable[..., Any],
-        max_entries: int = 1000,
-        cache_context: bool = False,
-    ):
-        super().__init__(
-            orig, num_args=None, uncached_args=None, cache_context=cache_context
-        )
-        self.max_entries = max_entries
-
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
-        cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.name,
-            max_size=self.max_entries,
-        )
-
-        get_cache_key = self.cache_key_builder
-        sentinel = LruCacheDescriptor._Sentinel.sentinel
-
-        @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
-            invalidate_callback = kwargs.pop("on_invalidate", None)
-            callbacks = (invalidate_callback,) if invalidate_callback else ()
-
-            cache_key = get_cache_key(args, kwargs)
-
-            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
-            if ret != sentinel:
-                return ret
-
-            # Add our own `cache_context` to argument list if the wrapped function
-            # has asked for one
-            if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
-            ret2 = self.orig(obj, *args, **kwargs)
-            cache.set(cache_key, ret2, callbacks=callbacks)
-
-            return ret2
-
-        wrapped = cast(_CachedFunction, _wrapped)
-        wrapped.cache = cache
-        obj.__dict__[self.name] = wrapped
-
-        return wrapped
-
-
 class DeferredCacheDescriptor(_CacheDescriptorBase):
     """A method decorator that applies a memoizing cache around the function.
 
@@ -324,7 +220,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+    def __get__(
+        self, obj: Optional[Any], owner: Optional[Type]
+    ) -> Callable[..., "defer.Deferred[Any]"]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.name,
             max_entries=self.max_entries,
@@ -336,7 +234,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -363,7 +261,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(ret)
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
 
         if self.num_args == 1:
             assert not self.tree
@@ -431,6 +329,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
+        if num_args != self.num_args:
+            raise TypeError(
+                "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+                % (self.num_args, self.cached_method_name, num_args)
+            )
+
         @functools.wraps(self.orig)
         def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
@@ -572,7 +476,7 @@ def cached(
     iterable: bool = False,
     prune_unread_entries: bool = True,
     name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -585,7 +489,7 @@ def cached(
         name=name,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def cachedList(
@@ -594,14 +498,14 @@ def cachedList(
     list_name: str,
     num_args: Optional[int] = None,
     name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. One of the arguments
     is specified as a list that is iterated through to lookup keys in the
     original cache. A new tuple consisting of the (deduplicated) keys that weren't in
     the cache gets passed to the original function, which is expected to result
-    in a map of key to value for each passed value. THe new results are stored in the
+    in a map of key to value for each passed value. The new results are stored in the
     original cache. Note that any missing values are cached as None.
 
     Args:
@@ -631,7 +535,7 @@ def cachedList(
         name=name,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def _get_cache_key_builder(
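
For context, a usage sketch of the `@cached`/`@cachedList` pairing that the new `num_args` check validates; it assumes a Synapse checkout on the path, and the store class and backing dict are invented for illustration:

    from typing import Dict, Iterable

    from synapse.util.caches.descriptors import cached, cachedList

    FAKE_DB = {"a": "alpha", "b": "beta"}  # hypothetical backing data

    class ExampleStore:
        @cached()
        async def get_thing(self, thing_id: str) -> str:
            return FAKE_DB[thing_id]

        # list_name must name an argument of this method, and the remaining
        # arguments must line up with get_thing's cache key; the new
        # TypeError fires on first access if the argument counts disagree.
        @cachedList(cached_method_name="get_thing", list_name="thing_ids")
        async def get_things(self, thing_ids: Iterable[str]) -> Dict[str, str]:
            return {t: FAKE_DB[t] for t in thing_ids}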
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index fa91479c97..2fbc7b1e6c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,7 +14,7 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
+from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
 from typing_extensions import Literal
@@ -33,10 +33,8 @@ DKT = TypeVar("DKT")
 DV = TypeVar("DV")
 
 
-# This class can't be generic because it uses slots with attrs.
-# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class DictionaryEntry:  # should be: Generic[DKT, DV].
+class DictionaryEntry(Generic[DKT, DV]):
     """Returned when getting an entry from the cache
 
     If `full` is true then `known_absent` will be the empty set.
@@ -50,8 +48,8 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
     """
 
     full: bool
-    known_absent: Set[Any]  # should be: Set[DKT]
-    value: Dict[Any, Any]  # should be: Dict[DKT, DV]
+    known_absent: Set[DKT]
+    value: Dict[DKT, DV]
 
     def __len__(self) -> int:
         return len(self.value)
@@ -169,10 +167,11 @@ class DictionaryCache(Generic[KT, DKT, DV]):
                 if it is in the cache.
 
         Returns:
-            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
-            will contain include the keys that are in the cache. If None then
-            will either return the full dict if in the cache, or the empty
-            dict (with `full` set to False) if it isn't.
+            If `dict_keys` is not None then `DictionaryEntry` will include the
+            keys that are in the cache.
+
+            If None then will either return the full dict if in the cache, or the
+            empty dict (with `full` set to False) if it isn't.
         """
         if dict_keys is None:
             # The caller wants the full set of dictionary keys for this cache key
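
Making `DictionaryEntry` generic is possible because recent attrs releases support `slots=True` on generic classes (the python-attrs/attrs#313 limitation referenced by the deleted comment). A standalone sketch of the pattern, assuming a current attrs:

    from typing import Dict, Generic, Set, TypeVar

    import attr

    DKT = TypeVar("DKT")
    DV = TypeVar("DV")

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Entry(Generic[DKT, DV]):
        full: bool
        known_absent: Set[DKT]
        value: Dict[DKT, DV]

    # mypy can now check the type parameters at call sites:
    entry: Entry[str, int] = Entry(full=False, known_absent=set(), value={"a": 1})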
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c6a5d0dfc0..e73cf66080 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import OrderedDict
-from typing import Any, Generic, Optional, TypeVar, Union, overload
+from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload
 
 import attr
 from typing_extensions import Literal
@@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
+        self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict()
 
         self.iterable = iterable
 
@@ -84,9 +84,7 @@ class ExpiringCache(Generic[KT, VT]):
             return
 
         def f() -> "defer.Deferred[None]":
-            return run_as_background_process(
-                "prune_cache_%s" % self._cache_name, self._prune_cache
-            )
+            return run_as_background_process("prune_cache", self._prune_cache)
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
@@ -100,7 +98,10 @@ class ExpiringCache(Generic[KT, VT]):
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.size, len(value.value))
+                # type-ignore, here and below: if self.iterable is true, then the value
+                # type VT should be Sized (i.e. have a __len__ method). We don't enforce
+                # this via the type system at present.
+                self.metrics.inc_evictions(EvictionReason.size, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.size)
 
@@ -134,7 +135,7 @@ class ExpiringCache(Generic[KT, VT]):
             return default
 
         if self.iterable:
-            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))  # type: ignore[arg-type]
         else:
             self.metrics.inc_evictions(EvictionReason.invalidation)
 
@@ -182,7 +183,7 @@ class ExpiringCache(Generic[KT, VT]):
         for k in keys_to_delete:
             value = self._cache.pop(k)
             if self.iterable:
-                self.metrics.inc_evictions(EvictionReason.time, len(value.value))
+                self.metrics.inc_evictions(EvictionReason.time, len(value.value))  # type: ignore[arg-type]
             else:
                 self.metrics.inc_evictions(EvictionReason.time)
 
@@ -195,7 +196,8 @@ class ExpiringCache(Generic[KT, VT]):
 
     def __len__(self) -> int:
         if self.iterable:
-            return sum(len(entry.value) for entry in self._cache.values())
+            g: Iterable[int] = (len(entry.value) for entry in self._cache.values())  # type: ignore[arg-type]
+            return sum(g)
         else:
             return len(self._cache)
 
@@ -207,7 +209,7 @@ class ExpiringCache(Generic[KT, VT]):
         items from the cache.
 
         Returns:
-            bool: Whether the cache changed size or not.
+            Whether the cache changed size or not.
         """
         new_size = int(self._original_max_size * factor)
         if new_size != self._max_size:
@@ -218,6 +220,6 @@ class ExpiringCache(Generic[KT, VT]):
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _CacheEntry:
+class _CacheEntry(Generic[VT]):
     time: int
-    value: Any
+    value: VT
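
The `type: ignore`s above all stem from the same gap: with `iterable=True` the cache sizes itself by `len(entry.value)`, implicitly assuming VT is Sized. A standalone sketch of the two size accountings (a plain mock, not the real ExpiringCache):

    from collections import OrderedDict
    from typing import List

    cache: "OrderedDict[str, List[int]]" = OrderedDict(a=[1, 2, 3], b=[4])

    # iterable=True: size is the total number of contained items ...
    iterable_size = sum(len(v) for v in cache.values())  # 4
    # ... otherwise it is simply the number of keys.
    plain_size = len(cache)  # 2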
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 895594adbe..7d1e405457 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -94,10 +94,8 @@ VT = TypeVar("VT")
 # a general type var, distinct from either KT or VT
 T = TypeVar("T")
 
-P = TypeVar("P")
 
-
-class _TimedListNode(ListNode[P]):
+class _TimedListNode(ListNode[T]):
     """A `ListNode` that tracks last access time."""
 
     __slots__ = ["last_access_ts_secs"]
@@ -390,11 +388,11 @@ class LruCache(Generic[KT, VT]):
             cache_name: The name of this cache, for the prometheus metrics. If unset,
                 no metrics will be reported on this cache.
 
-            cache_type (type):
+            cache_type:
                 type of underlying cache to be used. Typically one of dict
                 or TreeCache.
 
-            size_callback (func(V) -> int | None):
+            size_callback:
 
             metrics_collection_callback:
                 metrics collection callback. This is called early in the metrics
@@ -404,7 +402,7 @@ class LruCache(Generic[KT, VT]):
 
                 Ignored if cache_name is None.
 
-            apply_cache_factor_from_config (bool): If true, `max_size` will be
+            apply_cache_factor_from_config: If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
 
             clock:
@@ -784,26 +782,21 @@ class LruCache(Generic[KT, VT]):
     def __contains__(self, key: KT) -> bool:
         return self.contains(key)
 
-    def set_cache_factor(self, factor: float) -> bool:
+    def set_cache_factor(self, factor: float) -> None:
         """
         Set the cache factor for this individual cache.
 
         This will trigger a resize if it changes, which may require evicting
         items from the cache.
-
-        Returns:
-            bool: Whether the cache changed size or not.
         """
         if not self.apply_cache_factor_from_config:
-            return False
+            return
 
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
             if self._on_resize:
                 self._on_resize()
-            return True
-        return False
 
     def __del__(self) -> None:
         # We're about to be deleted, so we make sure to clear up all the nodes
@@ -822,7 +815,7 @@ class AsyncLruCache(Generic[KT, VT]):
     utilize external cache systems that require await behaviour to be created.
     """
 
-    def __init__(self, *args, **kwargs):  # type: ignore
+    def __init__(self, *args: Any, **kwargs: Any):
         self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
 
     async def get(
@@ -845,7 +838,13 @@ class AsyncLruCache(Generic[KT, VT]):
         return self._lru_cache.get(key, update_metrics=update_metrics)
 
     async def set(self, key: KT, value: VT) -> None:
-        self._lru_cache.set(key, value)
+        # This will add the entries in the correct order: local first, then external.
+        self.set_local(key, value)
+        await self.set_external(key, value)
+
+    async def set_external(self, key: KT, value: VT) -> None:
+        # This method should add an entry to any configured external cache; in this class it is a no-op.
+        pass
 
     def set_local(self, key: KT, value: VT) -> None:
         self._lru_cache.set(key, value)
@@ -865,5 +864,5 @@ class AsyncLruCache(Generic[KT, VT]):
     async def contains(self, key: KT) -> bool:
         return self._lru_cache.contains(key)
 
-    async def clear(self) -> None:
+    def clear(self) -> None:
         self._lru_cache.clear()
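
With `set()` now writing the local cache before awaiting the external one, a subclass only needs to override the `*_external` hooks. A hypothetical sketch; the Redis client and its `set` coroutine are assumptions, not part of Synapse:

    from typing import Any

    from synapse.util.caches.lrucache import AsyncLruCache

    class RedisBackedLruCache(AsyncLruCache[str, bytes]):
        def __init__(self, redis_client: Any, *args: Any, **kwargs: Any):
            super().__init__(*args, **kwargs)
            self._redis = redis_client

        async def set_external(self, key: str, value: bytes) -> None:
            # By the time set() awaits this, the local write has happened.
            await self._redis.set(key, value)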
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a3eb5f741b..0cb46700a9 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -36,7 +36,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.util import Clock
 from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
-from synapse.util.caches import register_cache
+from synapse.util.caches import EvictionReason, register_cache
 
 logger = logging.getLogger(__name__)
 
@@ -167,12 +167,10 @@ class ResponseCache(Generic[KV]):
             # the should_cache bit, we leave it in the cache for now and schedule
             # its removal later.
             if self.timeout_sec and context.should_cache:
-                self.clock.call_later(
-                    self.timeout_sec, self._result_cache.pop, key, None
-                )
+                self.clock.call_later(self.timeout_sec, self._entry_timeout, key)
             else:
                 # otherwise, remove the result immediately.
-                self._result_cache.pop(key, None)
+                self.unset(key)
             return r
 
         # make sure we do this *after* adding the entry to result_cache,
@@ -181,6 +179,20 @@ class ResponseCache(Generic[KV]):
         result.addBoth(on_complete)
         return entry
 
+    def unset(self, key: KV) -> None:
+        """Remove the cached value for this key from the cache, if any.
+
+        Args:
+            key: key used to remove the cached value
+        """
+        self._metrics.inc_evictions(EvictionReason.invalidation)
+        self._result_cache.pop(key, None)
+
+    def _entry_timeout(self, key: KV) -> None:
+        """For the call_later to remove from the cache"""
+        self._metrics.inc_evictions(EvictionReason.time)
+        self._result_cache.pop(key, None)
+
     async def wrap(
         self,
         key: KV,
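
A usage sketch for the new `unset()`; the handler is hypothetical and a Synapse checkout is assumed. `wrap()` deduplicates concurrent callers for the same key, and `unset()` now records the removal as an invalidation eviction:

    from synapse.util import Clock
    from synapse.util.caches.response_cache import ResponseCache

    def build_cache(clock: Clock) -> "ResponseCache[str]":
        # Keep completed responses around for five seconds.
        return ResponseCache(clock, "example", timeout_ms=5000)

    async def handle(cache: "ResponseCache[str]", key: str) -> str:
        async def expensive_handler() -> str:  # hypothetical work
            return "response for " + key

        # Concurrent callers with the same key share one invocation.
        return await cache.wrap(key, expensive_handler)

    def invalidate(cache: "ResponseCache[str]", key: str) -> None:
        cache.unset(key)  # recorded as an invalidation eviction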
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 330709b8b7..1657459549 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -16,6 +16,7 @@ import logging
 import math
 from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
 
+import attr
 from sortedcontainers import SortedDict
 
 from synapse.util import caches
@@ -26,14 +27,41 @@ logger = logging.getLogger(__name__)
 EntityType = str
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class AllEntitiesChangedResult:
+    """Return type of `get_all_entities_changed`.
+
+    Callers must check that there was a cache hit, via `result.hit`, before
+    using the entities in `result.entities`.
+
+    This specifically does *not* implement helpers such as `__bool__` to ensure
+    that callers do the correct checks.
+    """
+
+    _entities: Optional[List[EntityType]]
+
+    @property
+    def hit(self) -> bool:
+        return self._entities is not None
+
+    @property
+    def entities(self) -> List[EntityType]:
+        assert self._entities is not None
+        return self._entities
+
+
 class StreamChangeCache:
-    """Keeps track of the stream positions of the latest change in a set of entities.
+    """
+    Keeps track of the stream positions of the latest change in a set of entities.
+
+    The entity is typically a room ID or user ID, but can be any string.
 
-    Typically the entity will be a room or user id.
+    Can be queried for whether a specific entity has changed after a stream position
+    or for a list of changed entities after a stream position. See the individual
+    methods for more information.
 
-    Given a list of entities and a stream position, it will give a subset of
-    entities that may have changed since that position. If position key is too
-    old then the cache will simply return all given entities.
+    Only tracks up to a maximum cache size; any position earlier than the earliest
+    known stream position must be treated as unknown.
     """
 
     def __init__(
@@ -45,16 +73,20 @@ class StreamChangeCache:
     ) -> None:
         self._original_max_size: int = max_size
         self._max_size = math.floor(max_size)
-        self._entity_to_key: Dict[EntityType, int] = {}
 
-        # map from stream id to the a set of entities which changed at that stream id.
+        # map from stream id to the set of entities which changed at that stream id.
         self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
+        # map from entity to the stream ID of the latest change for that entity.
+        #
+        # Must be kept in sync with _cache.
+        self._entity_to_key: Dict[EntityType, int] = {}
 
         # the earliest stream_pos for which we can reliably answer
         # get_all_entities_changed. In other words, one less than the earliest
         # stream_pos for which we know _cache is valid.
         #
         self._earliest_known_stream_pos = current_stream_pos
+
         self.name = name
         self.metrics = caches.register_cache(
             "cache", self.name, self._cache, resize_callback=self.set_cache_factor
@@ -72,7 +104,7 @@ class StreamChangeCache:
         items from the cache.
 
         Returns:
-            bool: Whether the cache changed size or not.
+            Whether the cache changed size or not.
         """
         new_size = math.floor(self._original_max_size * factor)
         if new_size != self._max_size:
@@ -82,22 +114,46 @@ class StreamChangeCache:
         return False
 
     def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
-        """Returns True if the entity may have been updated since stream_pos"""
+        """
+        Returns True if the entity may have been updated after stream_pos.
+
+        Args:
+            entity: The entity to check for changes.
+            stream_pos: The stream position to check for changes after.
+
+        Returns:
+            True if the entity may have been updated; this happens if:
+                * The given stream position is at or earlier than the earliest
+                  known stream position.
+                * The given stream position is earlier than the latest change for
+                  the entity.
+
+            False otherwise:
+                * The entity is unknown.
+                * The given stream position is at or later than the latest change
+                  for the entity.
+        """
         assert isinstance(stream_pos, int)
 
-        if stream_pos < self._earliest_known_stream_pos:
+        # _cache is not valid at or before the earliest known stream position, so
+        # return that the entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
+        # If the entity is unknown, it hasn't changed.
         latest_entity_change_pos = self._entity_to_key.get(entity, None)
         if latest_entity_change_pos is None:
             self.metrics.inc_hits()
             return False
 
+        # This is a known entity, return true if the stream position is earlier
+        # than the last change.
         if stream_pos < latest_entity_change_pos:
             self.metrics.inc_misses()
             return True
 
+        # Otherwise, the stream position is after the latest change: return false.
         self.metrics.inc_hits()
         return False
 
@@ -105,23 +161,35 @@ class StreamChangeCache:
         self, entities: Collection[EntityType], stream_pos: int
     ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
         """
-        Returns subset of entities that have had new things since the given
-        position.  Entities unknown to the cache will be returned.  If the
-        position is too old it will just return the given list.
+        Returns the subset of the given entities that have had changes after the given position.
+
+        Entities unknown to the cache will be returned.
+
+        If the position is too old it will just return the given list.
+
+        Args:
+            entities: Entities to check for changes.
+            stream_pos: The stream position to check for changes after.
+
+        Returns:
+            A subset of entities which have changed after the given stream position.
+
+            This will be all entities if the given stream position is at or earlier
+            than the earliest known stream position.
         """
-        changed_entities = self.get_all_entities_changed(stream_pos)
-        if changed_entities is not None:
+        cache_result = self.get_all_entities_changed(stream_pos)
+        if cache_result.hit:
             # We now do an intersection, trying to do so in the most efficient
             # way possible (some of these sets are *large*). First check if the
-            # given iterable is already set that we can reuse, otherwise we
+            # given iterable is already a set that we can reuse, otherwise we
             # create a set of the *smallest* of the two iterables and call
             # `intersection(..)` on it (this can be twice as fast as the reverse).
             if isinstance(entities, (set, frozenset)):
-                result = entities.intersection(changed_entities)
-            elif len(changed_entities) < len(entities):
-                result = set(changed_entities).intersection(entities)
+                result = entities.intersection(cache_result.entities)
+            elif len(cache_result.entities) < len(entities):
+                result = set(cache_result.entities).intersection(entities)
             else:
-                result = set(entities).intersection(changed_entities)
+                result = set(entities).intersection(cache_result.entities)
             self.metrics.inc_hits()
         else:
             result = set(entities)
@@ -130,43 +198,76 @@ class StreamChangeCache:
         return result
 
     def has_any_entity_changed(self, stream_pos: int) -> bool:
-        """Returns if any entity has changed"""
-        assert type(stream_pos) is int
+        """
+        Returns true if any entity has changed after the given stream position.
 
-        if not self._cache:
-            # If the cache is empty, nothing can have changed.
-            return False
+        Args:
+            stream_pos: The stream position to check for changes after.
 
-        if stream_pos >= self._earliest_known_stream_pos:
-            self.metrics.inc_hits()
-            return self._cache.bisect_right(stream_pos) < len(self._cache)
-        else:
+        Returns:
+            True if any entity has changed after the given stream position or
+            if the given stream position is at or earlier than the earliest
+            known stream position.
+
+            False otherwise.
+        """
+        assert isinstance(stream_pos, int)
+
+        # _cache is not valid at or before the earliest known stream position, so
+        # return that an entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
             self.metrics.inc_misses()
             return True
 
-    def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
-        """Returns all entities that have had new things since the given
-        position. If the position is too old it will return None.
+        # If the cache is empty, nothing can have changed.
+        if not self._cache:
+            self.metrics.inc_misses()
+            return False
+
+        self.metrics.inc_hits()
+        return stream_pos < self._cache.peekitem()[0]
+
+    def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
+        """
+        Returns all entities that have had changes after the given position.
+
+        If the stream change cache does not go far enough back, i.e. the
+        position is too old, the result will indicate a cache miss via `hit`.
 
         Returns the entities in the order that they were changed.
+
+        Args:
+            stream_pos: The stream position to check for changes after.
+
+        Returns:
+            An object indicating whether the requested data is cached; on a
+            hit it also carries the entities in the order they were changed.
         """
-        assert type(stream_pos) is int
+        assert isinstance(stream_pos, int)
 
-        if stream_pos < self._earliest_known_stream_pos:
-            return None
+        # _cache is not valid at or before the earliest known stream position, so
+        # return None to mark that it is unknown if an entity has changed.
+        if stream_pos <= self._earliest_known_stream_pos:
+            return AllEntitiesChangedResult(None)
 
         changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
-        return changed_entities
+        return AllEntitiesChangedResult(changed_entities)
 
     def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
-        """Informs the cache that the entity has been changed at the given
-        position.
         """
-        assert type(stream_pos) is int
+        Informs the cache that the entity has been changed at the given position.
+
+        Args:
+            entity: The entity to mark as changed.
+            stream_pos: The stream position to update the entity to.
+        """
+        assert isinstance(stream_pos, int)
 
+        # For a change before _cache is valid (i.e. at or before the earliest known
+        # stream position) there's nothing to do.
         if stream_pos <= self._earliest_known_stream_pos:
             return
 
@@ -188,14 +289,13 @@ class StreamChangeCache:
         self._entity_to_key[entity] = stream_pos
         self._evict()
 
-        # if the cache is too big, remove entries
-        while len(self._cache) > self._max_size:
-            k, r = self._cache.popitem(0)
-            self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
-            for entity in r:
-                del self._entity_to_key[entity]
-
     def _evict(self) -> None:
+        """
+        Ensure the cache has not exceeded the maximum size.
+
+        Evicts entries until it is at the maximum size.
+        """
+        # if the cache is too big, remove entries
         while len(self._cache) > self._max_size:
             k, r = self._cache.popitem(0)
             self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
@@ -203,8 +303,14 @@ class StreamChangeCache:
                 self._entity_to_key.pop(entity, None)
 
     def get_max_pos_of_last_change(self, entity: EntityType) -> int:
-
         """Returns an upper bound of the stream id of the last change to an
         entity.
+
+        Args:
+            entity: The entity to check.
+
+        Returns:
+            The stream position of the latest change for the given entity or
+            the earliest known stream position if the entity is unknown.
         """
         return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
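
A usage sketch of the new `AllEntitiesChangedResult` calling convention (assumes a Synapse checkout; the entity ID is invented):

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache("example", current_stream_pos=5)
    cache.entity_has_changed("@alice:example.org", 7)

    result = cache.get_all_entities_changed(6)
    if result.hit:
        print(result.entities)  # ['@alice:example.org']
    else:
        # The position predates the cache; treat everything as changed.
        pass

    # Positions at or before the earliest known position are always a miss:
    assert not cache.get_all_entities_changed(5).hit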
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index f6b3ee31e4..48a6e4a906 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data: Dict[KT, _CacheEntry] = {}
+        self._data: Dict[KT, _CacheEntry[KT, VT]] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list: SortedList[_CacheEntry] = SortedList()
+        self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList()
 
         self._timer = timer
 
@@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]):
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
-class _CacheEntry:  # Should be Generic[KT, VT]. See python-attrs/attrs#313
+class _CacheEntry(Generic[KT, VT]):
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
     expiry_time: float
     ttl: float
-    key: Any  # should be KT
-    value: Any  # should be VT
+    key: KT
+    value: VT
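
With `_CacheEntry` generic, TTLCache keys and values are typed end to end. A small usage sketch, assuming a Synapse checkout:

    from synapse.util.caches.ttlcache import TTLCache

    cache: "TTLCache[str, int]" = TTLCache("example")
    cache.set("answer", 42, ttl=30.0)  # entry expires 30 seconds from now
    value: int = cache.get("answer")   # mypy now sees an int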