summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py77
1 files changed, 41 insertions, 36 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 39dce9dd41..a0a7a9de32 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -40,7 +40,7 @@ from twisted.internet.interfaces import IReactorTime
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.util import Clock, caches
-from synapse.util.caches import CacheMetric, register_cache
+from synapse.util.caches import CacheMetric, EvictionReason, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.linked_list import ListNode
 
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
 try:
     from pympler.asizeof import Asizer
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         """Get an estimate of the size in bytes of the object.
 
         Args:
@@ -71,7 +71,7 @@ try:
 
 except ImportError:
 
-    def _get_size_of(val: Any, *, recurse=True) -> int:
+    def _get_size_of(val: Any, *, recurse: bool = True) -> int:
         return 0
 
 
@@ -85,15 +85,6 @@ VT = TypeVar("VT")
 # a general type var, distinct from either KT or VT
 T = TypeVar("T")
 
-
-def enumerate_leaves(node, depth):
-    if depth == 0:
-        yield node
-    else:
-        for n in node.values():
-            yield from enumerate_leaves(n, depth - 1)
-
-
 P = TypeVar("P")
 
 
@@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]):
 
     __slots__ = ["last_access_ts_secs"]
 
-    def update_last_access(self, clock: Clock):
+    def update_last_access(self, clock: Clock) -> None:
         self.last_access_ts_secs = int(clock.time())
 
 
@@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int):
+async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
     """Walks the global cache list to find cache entries that haven't been
     accessed in the given number of seconds.
     """
@@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int):
     logger.info("Dropped %d items from caches", i)
 
 
-def setup_expire_lru_cache_entries(hs: "HomeServer"):
+def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
     been accessed for the given number of seconds.
     """
@@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"):
     )
 
 
-class _Node:
+class _Node(Generic[KT, VT]):
     __slots__ = [
         "_list_node",
         "_global_list_node",
@@ -197,15 +188,16 @@ class _Node:
     def __init__(
         self,
         root: "ListNode[_Node]",
-        key,
-        value,
+        key: KT,
+        value: VT,
         cache: "weakref.ReferenceType[LruCache]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
+        prune_unread_entries: bool = True,
     ):
         self._list_node = ListNode.insert_after(self, root)
-        self._global_list_node = None
-        if USE_GLOBAL_LIST:
+        self._global_list_node: Optional[_TimedListNode] = None
+        if USE_GLOBAL_LIST and prune_unread_entries:
             self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT)
             self._global_list_node.update_last_access(clock)
 
@@ -314,6 +306,7 @@ class LruCache(Generic[KT, VT]):
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
+        prune_unread_entries: bool = True,
     ):
         """
         Args:
@@ -403,11 +396,11 @@ class LruCache(Generic[KT, VT]):
                 evicted_len = delete_node(node)
                 cache.pop(node.key, None)
                 if metrics:
-                    metrics.inc_evictions(evicted_len)
+                    metrics.inc_evictions(EvictionReason.size, evicted_len)
 
         def synchronized(f: FT) -> FT:
             @wraps(f)
-            def inner(*args, **kwargs):
+            def inner(*args: Any, **kwargs: Any) -> Any:
                 with lock:
                     return f(*args, **kwargs)
 
@@ -416,18 +409,28 @@ class LruCache(Generic[KT, VT]):
         cached_cache_len = [0]
         if size_callback is not None:
 
-            def cache_len():
+            def cache_len() -> int:
                 return cached_cache_len[0]
 
         else:
 
-            def cache_len():
+            def cache_len() -> int:
                 return len(cache)
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
-            node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks)
+        def add_node(
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+        ) -> None:
+            node = _Node(
+                list_root,
+                key,
+                value,
+                weak_ref_to_self,
+                real_clock,
+                callbacks,
+                prune_unread_entries,
+            )
             cache[key] = node
 
             if size_callback:
@@ -436,7 +439,7 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node):
+        def move_node_to_front(node: _Node) -> None:
             node.move_to_front(real_clock, list_root)
 
         def delete_node(node: _Node) -> int:
@@ -478,7 +481,7 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
-        ):
+        ) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -492,7 +495,9 @@ class LruCache(Generic[KT, VT]):
                 return default
 
         @synchronized
-        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()):
+        def cache_set(
+            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+        ) -> None:
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -537,7 +542,7 @@ class LruCache(Generic[KT, VT]):
             ...
 
         @synchronized
-        def cache_pop(key: KT, default: Optional[T] = None):
+        def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -602,25 +607,25 @@ class LruCache(Generic[KT, VT]):
         self.contains = cache_contains
         self.clear = cache_clear
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         result = self.get(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
         else:
-            return result
+            return cast(VT, result)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
-    def __delitem__(self, key, value):
+    def __delitem__(self, key: KT, value: VT) -> None:
         result = self.pop(key, self.sentinel)
         if result is self.sentinel:
             raise KeyError()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.len()
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return self.contains(key)
 
     def set_cache_factor(self, factor: float) -> bool: