diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
import threading
from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import Literal
from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
def enumerate_leaves(node, depth):
if depth == 0:
@@ -41,30 +65,33 @@ class _Node:
self.callbacks = callbacks
-class LruCache:
+class LruCache(Generic[KT, VT]):
"""
- Least-recently-used cache.
+ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
-
- Can also set callbacks on objects when getting/setting which are fired
- when that key gets invalidated/evicted.
"""
def __init__(
self,
max_size: int,
+ cache_name: Optional[str] = None,
keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None,
- evicted_callback: Optional[Callable] = None,
+ metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
max_size: The maximum amount of entries the cache can hold
- keylen: The length of the tuple used as the cache key
+ cache_name: The name of this cache, for the prometheus metrics. If unset,
+ no metrics will be reported on this cache.
+
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict
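Making LruCache generic over KT/VT lets call sites pin the key and value types, and passing cache_name opts the cache into metrics. A minimal usage sketch (the name "example_cache" is hypothetical, not a real Synapse cache):

    # Hypothetical usage; "example_cache" is an illustrative metrics name.
    cache: LruCache[str, int] = LruCache(max_size=100, cache_name="example_cache")
    cache.set("answer", 42)
    value = cache.get("answer")  # checker infers Optional[int]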
@@ -72,9 +99,13 @@ class LruCache:
size_callback (func(V) -> int | None):
- evicted_callback (func(int)|None):
- if not None, called on eviction with the size of the evicted
- entry
+ metrics_collection_callback:
+ metrics collection callback. This is called early in the metrics
+ collection process, before any of the metrics registered with the
+ prometheus Registry are collected, so it can be used to update any
+ dynamic metrics.
+
+ Ignored if cache_name is None.
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
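For illustration, a metrics_collection_callback can refresh dynamic values just before each scrape. A sketch assuming prometheus_client is available (the gauge name is made up):

    import time
    from prometheus_client import Gauge

    # Hypothetical gauge; not part of synapse.
    last_collect = Gauge("example_last_collection_ts", "time of last metrics collection")

    def refresh() -> None:
        # Runs early in the collection pass, before registered metrics
        # are gathered, so dynamic values can be updated here.
        last_collect.set(time.time())

    cache: LruCache[str, int] = LruCache(
        max_size=100, cache_name="example_cache", metrics_collection_callback=refresh
    )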
@@ -93,6 +124,23 @@ class LruCache:
else:
self.max_size = int(max_size)
+ # register_cache might call our "set_cache_factor" callback; there's nothing to
+ # do yet when we get resized.
+ self._on_resize = None # type: Optional[Callable[[], None]]
+
+ if cache_name is not None:
+ metrics = register_cache(
+ "lru_cache",
+ cache_name,
+ self,
+ collect_callback=metrics_collection_callback,
+ ) # type: Optional[CacheMetric]
+ else:
+ metrics = None
+
+ # this is exposed for access from outside this class
+ self.metrics = metrics
+
list_root = _Node(None, None, None, None)
list_root.next_node = list_root
list_root.prev_node = list_root
@@ -104,16 +152,16 @@ class LruCache:
todelete = list_root.prev_node
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
- if evicted_callback:
- evicted_callback(evicted_len)
+ if metrics:
+ metrics.inc_evictions(evicted_len)
- def synchronized(f):
+ def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
- return inner
+ return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
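The cast back to FT is what lets @synchronized preserve the wrapped function's signature for the type checker, since the inner closure is only typed as a bare callable. A standalone sketch of the same pattern:

    import threading
    from functools import wraps
    from typing import Any, Callable, TypeVar, cast

    FT = TypeVar("FT", bound=Callable[..., Any])
    lock = threading.Lock()

    def synchronized(f: FT) -> FT:
        @wraps(f)
        def inner(*args: Any, **kwargs: Any) -> Any:
            with lock:
                return f(*args, **kwargs)

        # mypy sees `inner` as Callable[..., Any]; the cast restores the
        # decorated function's precise signature for callers.
        return cast(FT, inner)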
@@ -167,18 +215,45 @@ class LruCache:
node.callbacks.clear()
return deleted_len
+ @overload
+ def cache_get(
+ key: KT,
+ default: Literal[None] = None,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_get(
+ key: KT,
+ default: T,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_get(key, default=None, callbacks=[]):
+ def cache_get(
+ key: KT,
+ default: Optional[T] = None,
+ callbacks: Iterable[Callable[[], None]] = [],
+ update_metrics: bool = True,
+ ):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node.callbacks.update(callbacks)
+ if update_metrics and metrics:
+ metrics.inc_hits()
return node.value
else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
return default
@synchronized
- def cache_set(key, value, callbacks=[]):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
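Under these overloads the checker narrows cache_get's return type by whether a default is supplied; roughly, for a cache typed LruCache[str, int]:

    cache: LruCache[str, int] = LruCache(max_size=10)
    hit = cache.get("key")           # inferred as Optional[int]
    safe = cache.get("key", "miss")  # inferred as Union[str, int]
    quiet = cache.get("key", update_metrics=False)  # no hit/miss counting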
@@ -207,7 +282,7 @@ class LruCache:
evict()
@synchronized
- def cache_set_default(key, value):
+ def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@@ -216,8 +291,16 @@ class LruCache:
evict()
return value
+ @overload
+ def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_pop(key, default=None):
+ def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@@ -227,18 +310,18 @@ class LruCache:
return default
@synchronized
- def cache_del_multi(key):
+ def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(key)):
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
- def cache_clear():
+ def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
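del_multi relies on TreeCache storing tuple keys as a trie: popping a key prefix returns a subtree, whose leaves are then unlinked one by one. A sketch, assuming keylen=2 tuple keys:

    cache: LruCache[tuple, str] = LruCache(max_size=10, keylen=2, cache_type=TreeCache)
    cache.set(("user", "a"), "x")
    cache.set(("user", "b"), "y")

    # removes every entry whose key starts with ("user",)
    cache.del_multi(("user",))
    assert cache.get(("user", "a")) is None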
@@ -249,15 +332,21 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
- def cache_contains(key):
+ def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
+
+ # make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
+
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ # `invalidate` is exposed for consistency with DeferredCache, so that it can be
+ # invalidated by the cache invalidation replication stream.
+ self.invalidate = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = synchronized(cache_len)
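Because invalidate is just an alias for cache_pop, invalidating (e.g. from the replication stream) is an ordinary pop that also fires any registered callbacks; a sketch:

    cache: LruCache[str, int] = LruCache(max_size=10)
    cache.set("k", 1, callbacks=[lambda: print("k invalidated")])
    cache.invalidate("k")  # same as cache.pop("k"); runs the callback
    assert cache.get("k") is None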
@@ -301,6 +390,7 @@ class LruCache:
new_size = int(self._original_max_size * factor)
if new_size != self.max_size:
self.max_size = new_size
- self._on_resize()
+ if self._on_resize:
+ self._on_resize()
return True
return False
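With this guard, a resize during construction (before _on_resize is assigned) is a no-op, while later shrinks evict down to the new size. An illustrative sketch (assumes the cache-factor config machinery is initialised):

    cache: LruCache[int, int] = LruCache(max_size=10)
    for i in range(10):
        cache.set(i, i)

    # halving the factor shrinks max_size; _on_resize -> evict trims the excess
    cache.set_cache_factor(0.5)
    assert cache.len() <= 5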