summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py138
1 files changed, 114 insertions, 24 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -41,30 +65,33 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
-    Least-recently-used cache.
+    Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
     Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
-
-    Can also set callbacks on objects when getting/setting which are fired
-    when that key gets invalidated/evicted.
     """
 
     def __init__(
         self,
         max_size: int,
+        cache_name: Optional[str] = None,
         keylen: int = 1,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable] = None,
-        evicted_callback: Optional[Callable] = None,
+        metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
     ):
         """
         Args:
             max_size: The maximum amount of entries the cache can hold
 
-            keylen: The length of the tuple used as the cache key
+            cache_name: The name of this cache, for the prometheus metrics. If unset,
+                no metrics will be reported on this cache.
+
+            keylen: The length of the tuple used as the cache key. Ignored unless
+                cache_type is `TreeCache`.
 
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
@@ -72,9 +99,13 @@ class LruCache:
 
             size_callback (func(V) -> int | None):
 
-            evicted_callback (func(int)|None):
-                if not None, called on eviction with the size of the evicted
-                entry
+            metrics_collection_callback:
+                metrics collection callback. This is called early in the metrics
+                collection process, before any of the metrics registered with the
+                prometheus Registry are collected, so can be used to update any dynamic
+                metrics.
+
+                Ignored if cache_name is None.
 
             apply_cache_factor_from_config (bool): If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
@@ -93,6 +124,23 @@ class LruCache:
         else:
             self.max_size = int(max_size)
 
+        # register_cache might call our "set_cache_factor" callback; there's nothing to
+        # do yet when we get resized.
+        self._on_resize = None  # type: Optional[Callable[[],None]]
+
+        if cache_name is not None:
+            metrics = register_cache(
+                "lru_cache",
+                cache_name,
+                self,
+                collect_callback=metrics_collection_callback,
+            )  # type: Optional[CacheMetric]
+        else:
+            metrics = None
+
+        # this is exposed for access from outside this class
+        self.metrics = metrics
+
         list_root = _Node(None, None, None, None)
         list_root.next_node = list_root
         list_root.prev_node = list_root
@@ -104,16 +152,16 @@ class LruCache:
                 todelete = list_root.prev_node
                 evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
-                if evicted_callback:
-                    evicted_callback(evicted_len)
+                if metrics:
+                    metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -167,18 +215,45 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[]):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
                 node.callbacks.update(callbacks)
+                if update_metrics and metrics:
+                    metrics.inc_hits()
                 return node.value
             else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -207,7 +282,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -216,8 +291,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -227,18 +310,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -249,15 +332,21 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()
+
+        # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
+
         self.get = cache_get
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        # `invalidate` is exposed for consistency with DeferredCache, so that it can be
+        # invalidated by the cache invalidation replication stream.
+        self.invalidate = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
         self.len = synchronized(cache_len)
@@ -301,6 +390,7 @@ class LruCache:
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
-            self._on_resize()
+            if self._on_resize:
+                self._on_resize()
             return True
         return False