diff options
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r-- | synapse/util/caches/lrucache.py | 42 |
1 files changed, 17 insertions, 25 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index eb96f7e665..a0a7a9de32 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,15 +15,14 @@ import logging import threading import weakref -from enum import Enum from functools import wraps from typing import ( TYPE_CHECKING, Any, Callable, Collection, - Dict, Generic, + Iterable, List, Optional, Type, @@ -191,7 +190,7 @@ class _Node(Generic[KT, VT]): root: "ListNode[_Node]", key: KT, value: VT, - cache: "weakref.ReferenceType[LruCache[KT, VT]]", + cache: "weakref.ReferenceType[LruCache]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), prune_unread_entries: bool = True, @@ -271,10 +270,7 @@ class _Node(Generic[KT, VT]): removed from all lists. """ cache = self._cache() - if ( - cache is None - or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel - ): + if not cache or not cache.pop(self.key, None): # `cache.pop` should call `drop_from_lists()`, unless this Node had # already been removed from the cache. self.drop_from_lists() @@ -294,12 +290,6 @@ class _Node(Generic[KT, VT]): self._global_list_node.update_last_access(clock) -class _Sentinel(Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. @@ -312,7 +302,7 @@ class LruCache(Generic[KT, VT]): max_size: int, cache_name: Optional[str] = None, cache_type: Type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable[[VT], int]] = None, + size_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, @@ -349,7 +339,7 @@ class LruCache(Generic[KT, VT]): else: real_clock = clock - cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() + cache = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -384,7 +374,7 @@ class LruCache(Generic[KT, VT]): # creating more each time we create a `_Node`. weak_ref_to_self = weakref.ref(self) - list_root = ListNode[_Node[KT, VT]].create_root_node() + list_root = ListNode[_Node].create_root_node() lock = threading.Lock() @@ -432,7 +422,7 @@ class LruCache(Generic[KT, VT]): def add_node( key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () ) -> None: - node: _Node[KT, VT] = _Node( + node = _Node( list_root, key, value, @@ -449,10 +439,10 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node: _Node[KT, VT]) -> None: + def move_node_to_front(node: _Node) -> None: node.move_to_front(real_clock, list_root) - def delete_node(node: _Node[KT, VT]) -> int: + def delete_node(node: _Node) -> int: node.drop_from_lists() deleted_len = 1 @@ -506,7 +496,7 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_set( - key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () ) -> None: node = cache.get(key, None) if node is not None: @@ -600,6 +590,8 @@ class LruCache(Generic[KT, VT]): def cache_contains(key: KT) -> bool: return key in cache + self.sentinel = object() + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict @@ -616,18 +608,18 @@ class LruCache(Generic[KT, VT]): self.clear = cache_clear def __getitem__(self, key: KT) -> VT: - result = self.get(key, _Sentinel.sentinel) - if result is _Sentinel.sentinel: + result = self.get(key, self.sentinel) + if result is self.sentinel: raise KeyError() else: - return result + return cast(VT, result) def __setitem__(self, key: KT, value: VT) -> None: self.set(key, value) def __delitem__(self, key: KT, value: VT) -> None: - result = self.pop(key, _Sentinel.sentinel) - if result is _Sentinel.sentinel: + result = self.pop(key, self.sentinel) + if result is self.sentinel: raise KeyError() def __len__(self) -> int: |