summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py90
1 files changed, 88 insertions, 2 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..b3bdedb04c 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+    TreeCache,
+    iterate_tree_cache_entry,
+    iterate_tree_cache_items,
+)
 from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
             default: Literal[None] = None,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Optional[VT]:
             ...
 
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
             default: T,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Union[T, VT]:
             ...
 
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
+            update_last_access: bool = True,
         ) -> Union[None, T, VT]:
+            """Look up a key in the cache
+
+            Args:
+                key
+                default
+                callbacks: A collection of callbacks that will fire when the
+                    node is removed from the cache (either due to invalidation
+                    or expiry).
+                update_metrics: Whether to update the hit rate metrics
+                update_last_access: Whether to update the last access metrics
+                    on a node if successfully fetched. These metrics are used
+                    to determine when to remove the node from the cache. Set
+                    to False if this fetch should *not* prevent a node from
+                    being expired.
+            """
             node = cache.get(key, None)
             if node is not None:
-                move_node_to_front(node)
+                if update_last_access:
+                    move_node_to_front(node)
                 node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
                     metrics.inc_misses()
                 return default
 
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: Literal[None] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: T,
+            update_metrics: bool = True,
+        ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @synchronized
+        def cache_get_multi(
+            key: tuple,
+            default: Optional[T] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+            """Returns a generator yielding all entries under the given key.
+
+            Can only be used if backed by a tree cache.
+
+            Example:
+
+                cache = LruCache(10, cache_type=TreeCache)
+                cache[(1, 1)] = "a"
+                cache[(1, 2)] = "b"
+                cache[(2, 1)] = "c"
+
+                items = cache.get_multi((1,))
+                assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+            Returns:
+                Either default if the key doesn't exist, or a generator of the
+                key/value pairs.
+            """
+
+            assert isinstance(cache, TreeCache)
+
+            node = cache.get(key, None)
+            if node is not None:
+                if update_metrics and metrics:
+                    metrics.inc_hits()
+
+                # We store entries in the `TreeCache` with values of type `_Node`,
+                # which we need to unwrap.
+                return (
+                    (full_key, lru_node.value)
+                    for full_key, lru_node in iterate_tree_cache_items(key, node)
+                )
+            else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
+                return default
+
         @synchronized
         def cache_set(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
         self.setdefault = cache_set_default
         self.pop = cache_pop
         self.del_multi = cache_del_multi
+        if cache_type is TreeCache:
+            self.get_multi = cache_get_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
         self.invalidate = cache_del_multi