summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/dictionary_cache.py218
-rw-r--r--synapse/util/caches/lrucache.py90
-rw-r--r--synapse/util/caches/treecache.py38
3 files changed, 313 insertions, 33 deletions
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
+from typing_extensions import Literal
 
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 logger = logging.getLogger(__name__)
 
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DictionaryEntry:  # should be: Generic[DKT, DV].
     """Returned when getting an entry from the cache
 
+    If `full` is true then `known_absent` will be the empty set.
+
     Attributes:
         full: Whether the cache has the full or dict or just some keys.
             If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
         return len(self.value)
 
 
+class _FullCacheKey(enum.Enum):
+    """The key we use to cache the full dict."""
+
+    KEY = object()
+
+
 class _Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
     # type of a dictionary lookup.
     sentinel = object()
 
 
+class _PerKeyValue(Generic[DV]):
+    """The cached value of a dictionary key. If `value` is the sentinel,
+    indicates that the requested key is known to *not* be in the full dict.
+    """
+
+    __slots__ = ["value"]
+
+    def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+        self.value = value
+
+    def __len__(self) -> int:
+        # We add a `__len__` implementation as we use this class in a cache
+        # where the values are variable length.
+        return 1
+
+
 class DictionaryCache(Generic[KT, DKT, DV]):
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
+
+    This cache has two levels of key. First there is the "cache key" (of type
+    `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+    type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+    example, it might look like:
+
+       {
+           1: { 1: "a", 2: "b" },
+           2: { 1: "c" },
+       }
+
+    It is possible to look up either individual dict keys, or the *complete*
+    dict for a given cache key.
+
+    Each dict item, and the complete dict is treated as a separate LRU
+    entry for the purpose of cache expiry. For example, given:
+        dict_cache.get(1, None)  -> DictionaryEntry({1: "a", 2: "b"})
+        dict_cache.get(1, [1])  -> DictionaryEntry({1: "a"})
+        dict_cache.get(1, [2])  -> DictionaryEntry({2: "b"})
+
+    ... then the cache entry for the complete dict will expire first,
+    followed by the cache entry for the '1' dict key, and finally that
+    for the '2' dict key.
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
-            max_size=max_entries, cache_name=name, size_callback=len
+        # We use a single LruCache to store two different types of entries:
+        #   1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+        #      the key doesn't exist in the dict); and
+        #   2. Map from (key, _FullCacheKey.KEY) -> full dict.
+        #
+        # The former is used when explicit keys of the dictionary are looked up,
+        # and the latter when the full dictionary is requested.
+        #
+        # If when explicit keys are requested and not in the cache, we then look
+        # to see if we have the full dict and use that if we do. If found in the
+        # full dict each key is added into the cache.
+        #
+        # This set up allows the `LruCache` to prune the full dict entries if
+        # they haven't been used in a while, even when there have been recent
+        # queries for subsets of the dict.
+        #
+        # Typing:
+        #     * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+        #     * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+        self.cache: LruCache[
+            Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+            Union[_PerKeyValue, Dict[DKT, DV]],
+        ] = LruCache(
+            max_size=max_entries,
+            cache_name=name,
+            cache_type=TreeCache,
+            size_callback=len,
         )
 
         self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         Args:
             key
             dict_keys: If given a set of keys then return only those keys
-                that exist in the cache.
+                that exist in the cache. If None then returns the full dict
+                if it is in the cache.
 
         Returns:
-            DictionaryEntry
+            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
+            will contain include the keys that are in the cache. If None then
+            will either return the full dict if in the cache, or the empty
+            dict (with `full` set to False) if it isn't.
         """
-        entry = self.cache.get(key, _Sentinel.sentinel)
-        if entry is not _Sentinel.sentinel:
-            if dict_keys is None:
-                return DictionaryEntry(
-                    entry.full, entry.known_absent, dict(entry.value)
-                )
+        if dict_keys is None:
+            # The caller wants the full set of dictionary keys for this cache key
+            return self._get_full_dict(key)
+
+        # We are being asked for a subset of keys.
+
+        # First go and check for each requested dict key in the cache, tracking
+        # which we couldn't find.
+        values = {}
+        known_absent = set()
+        missing = []
+        for dict_key in dict_keys:
+            entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+            if entry is _Sentinel.sentinel:
+                missing.append(dict_key)
+                continue
+
+            assert isinstance(entry, _PerKeyValue)
+
+            if entry.value is _Sentinel.sentinel:
+                known_absent.add(dict_key)
             else:
-                return DictionaryEntry(
-                    entry.full,
-                    entry.known_absent,
-                    {k: entry.value[k] for k in dict_keys if k in entry.value},
-                )
+                values[dict_key] = entry.value
+
+        # If we found everything we can return immediately.
+        if not missing:
+            return DictionaryEntry(False, known_absent, values)
+
+        # We are missing some keys, so check if we happen to have the full dict in
+        # the cache.
+        #
+        # We don't update the last access time for this cache fetch, as we
+        # aren't explicitly interested in the full dict and so we don't want
+        # requests for explicit dict keys to keep the full dict in the cache.
+        entry = self.cache.get(
+            (key, _FullCacheKey.KEY),
+            _Sentinel.sentinel,
+            update_last_access=False,
+        )
+        if entry is _Sentinel.sentinel:
+            # Not in the cache, return the subset of keys we found.
+            return DictionaryEntry(False, known_absent, values)
+
+        # We have the full dict!
+        assert isinstance(entry, dict)
+
+        for dict_key in missing:
+            # We explicitly add each dict key to the cache, so that cache hit
+            # rates and LRU times for each key can be tracked separately.
+            value = entry.get(dict_key, _Sentinel.sentinel)  # type: ignore[arg-type]
+            self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+            if value is not _Sentinel.sentinel:
+                values[dict_key] = value
+
+        return DictionaryEntry(True, set(), values)
+
+    def _get_full_dict(
+        self,
+        key: KT,
+    ) -> DictionaryEntry:
+        """Fetch the full dict for the given key."""
+
+        # First we check if we have cached the full dict.
+        entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
+            assert isinstance(entry, dict)
+            return DictionaryEntry(True, set(), entry)
 
         return DictionaryEntry(False, set(), {})
 
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(key, None)
+
+        # We want to drop all information about the dict for the given key, so
+        # we use `del_multi` to delete it all in one go.
+        #
+        # We ignore the type error here: `del_multi` accepts a truncated key
+        # (when the key type is a tuple).
+        self.cache.del_multi((key,))  # type: ignore[arg-type]
 
     def invalidate_all(self) -> None:
         self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         value: Dict[DKT, DV],
         fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
-        """Updates the entry in the cache
+        """Updates the entry in the cache.
+
+        Note: This does *not* invalidate any existing entries for the `key`.
+        In particular, if we add an entry for the cached "full dict" with
+        `fetched_keys=None`, existing entries for individual dict keys are
+        not invalidated. Likewise, adding entries for individual keys does
+        not invalidate any cached value for the full dict.
+
+        In other words: if the underlying data is *changed*, the cache must
+        be explicitly invalidated via `.invalidate()`.
 
         Args:
             sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
             if fetched_keys is None:
-                self._insert(key, value, set())
+                self.cache[(key, _FullCacheKey.KEY)] = value
             else:
-                self._update_or_insert(key, value, fetched_keys)
+                self._update_subset(key, value, fetched_keys)
 
-    def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+    def _update_subset(
+        self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
     ) -> None:
-        # We pop and reinsert as we need to tell the cache the size may have
-        # changed
+        """Add the given dictionary values as explicit keys in the cache.
+
+        Args:
+            key: top-level cache key
+            value: The dictionary with all the values that we should cache
+            fetched_keys: The full set of dict keys that were looked up. Any keys
+                here not in `value` should be marked as "known absent".
+        """
+
+        for dict_key, dict_value in value.items():
+            self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
 
-        entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
-        entry.value.update(value)
-        entry.known_absent.update(known_absent)
-        self.cache[key] = entry
+        for dict_key in fetched_keys:
+            if dict_key in value:
+                continue
 
-    def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
-        self.cache[key] = DictionaryEntry(True, known_absent, value)
+            self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..b3bdedb04c 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+    TreeCache,
+    iterate_tree_cache_entry,
+    iterate_tree_cache_items,
+)
 from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
             default: Literal[None] = None,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Optional[VT]:
             ...
 
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
             default: T,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Union[T, VT]:
             ...
 
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
+            update_last_access: bool = True,
         ) -> Union[None, T, VT]:
+            """Look up a key in the cache
+
+            Args:
+                key
+                default
+                callbacks: A collection of callbacks that will fire when the
+                    node is removed from the cache (either due to invalidation
+                    or expiry).
+                update_metrics: Whether to update the hit rate metrics
+                update_last_access: Whether to update the last access metrics
+                    on a node if successfully fetched. These metrics are used
+                    to determine when to remove the node from the cache. Set
+                    to False if this fetch should *not* prevent a node from
+                    being expired.
+            """
             node = cache.get(key, None)
             if node is not None:
-                move_node_to_front(node)
+                if update_last_access:
+                    move_node_to_front(node)
                 node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
                     metrics.inc_misses()
                 return default
 
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: Literal[None] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: T,
+            update_metrics: bool = True,
+        ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @synchronized
+        def cache_get_multi(
+            key: tuple,
+            default: Optional[T] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+            """Returns a generator yielding all entries under the given key.
+
+            Can only be used if backed by a tree cache.
+
+            Example:
+
+                cache = LruCache(10, cache_type=TreeCache)
+                cache[(1, 1)] = "a"
+                cache[(1, 2)] = "b"
+                cache[(2, 1)] = "c"
+
+                items = cache.get_multi((1,))
+                assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+            Returns:
+                Either default if the key doesn't exist, or a generator of the
+                key/value pairs.
+            """
+
+            assert isinstance(cache, TreeCache)
+
+            node = cache.get(key, None)
+            if node is not None:
+                if update_metrics and metrics:
+                    metrics.inc_hits()
+
+                # We store entries in the `TreeCache` with values of type `_Node`,
+                # which we need to unwrap.
+                return (
+                    (full_key, lru_node.value)
+                    for full_key, lru_node in iterate_tree_cache_items(key, node)
+                )
+            else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
+                return default
+
         @synchronized
         def cache_set(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
         self.setdefault = cache_set_default
         self.pop = cache_pop
         self.del_multi = cache_del_multi
+        if cache_type is TreeCache:
+            self.get_multi = cache_get_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
         self.invalidate = cache_del_multi
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..c1b8ec0c73 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
         self.size += 1
 
     def get(self, key, default=None):
+        """When `key` is a full key, fetches the value for the given key (if
+        any).
+
+        If `key` is only a partial key (i.e. a truncated tuple) then returns a
+        `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+        functions to iterate over all entries in the cache with keys that start
+        with the given partial key.
+        """
+
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
@@ -139,3 +148,32 @@ def iterate_tree_cache_entry(d):
             yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
+
+
+def iterate_tree_cache_items(key, value):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+    can contain dicts.
+
+    The provided key is a tuple that will get prepended to the returned keys.
+
+    Example:
+
+        cache = TreeCache()
+        cache[(1, 1)] = "a"
+        cache[(1, 2)] = "b"
+        cache[(2, 1)] = "c"
+
+        tree_node = cache.get((1,))
+
+        items = iterate_tree_cache_items((1,), tree_node)
+        assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+    Returns:
+        A generator yielding key/value pairs.
+    """
+    if isinstance(value, TreeCacheNode):
+        for sub_key, sub_value in value.items():
+            yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+    else:
+        # we've reached a leaf of the tree.
+        yield key, value