diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..b3bdedb04c 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
Collection,
Dict,
Generic,
+ Iterable,
List,
Optional,
+ Tuple,
Type,
TypeVar,
Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+ TreeCache,
+ iterate_tree_cache_entry,
+ iterate_tree_cache_items,
+)
from synapse.util.linked_list import ListNode
if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
default: Literal[None] = None,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Optional[VT]:
...
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
default: T,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Union[T, VT]:
...
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
+ update_last_access: bool = True,
) -> Union[None, T, VT]:
+ """Look up a key in the cache
+
+ Args:
+ key
+ default
+ callbacks: A collection of callbacks that will fire when the
+ node is removed from the cache (either due to invalidation
+ or expiry).
+ update_metrics: Whether to update the hit rate metrics
+ update_last_access: Whether to update the last access metrics
+ on a node if successfully fetched. These metrics are used
+ to determine when to remove the node from the cache. Set
+ to False if this fetch should *not* prevent a node from
+ being expired.
+ """
node = cache.get(key, None)
if node is not None:
- move_node_to_front(node)
+ if update_last_access:
+ move_node_to_front(node)
node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
metrics.inc_misses()
return default
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: Literal[None] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: T,
+ update_metrics: bool = True,
+ ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @synchronized
+ def cache_get_multi(
+ key: tuple,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+ """Returns a generator yielding all entries under the given key.
+
+ Can only be used if backed by a tree cache.
+
+ Example:
+
+ cache = LruCache(10, cache_type=TreeCache)
+ cache[(1, 1)] = "a"
+ cache[(1, 2)] = "b"
+ cache[(2, 1)] = "c"
+
+ items = cache.get_multi((1,))
+ assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+ Returns:
+ Either default if the key doesn't exist, or a generator of the
+ key/value pairs.
+ """
+
+ assert isinstance(cache, TreeCache)
+
+ node = cache.get(key, None)
+ if node is not None:
+ if update_metrics and metrics:
+ metrics.inc_hits()
+
+ # We store entries in the `TreeCache` with values of type `_Node`,
+ # which we need to unwrap.
+ return (
+ (full_key, lru_node.value)
+ for full_key, lru_node in iterate_tree_cache_items(key, node)
+ )
+ else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
+ return default
+
@synchronized
def cache_set(
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
self.setdefault = cache_set_default
self.pop = cache_pop
self.del_multi = cache_del_multi
+ if cache_type is TreeCache:
+ self.get_multi = cache_get_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_del_multi
|