diff options
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r-- | synapse/util/caches/lrucache.py | 77 |
1 files changed, 41 insertions, 36 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 39dce9dd41..a0a7a9de32 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -40,7 +40,7 @@ from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.util import Clock, caches -from synapse.util.caches import CacheMetric, register_cache +from synapse.util.caches import CacheMetric, EvictionReason, register_cache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.linked_list import ListNode @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) try: from pympler.asizeof import Asizer - def _get_size_of(val: Any, *, recurse=True) -> int: + def _get_size_of(val: Any, *, recurse: bool = True) -> int: """Get an estimate of the size in bytes of the object. Args: @@ -71,7 +71,7 @@ try: except ImportError: - def _get_size_of(val: Any, *, recurse=True) -> int: + def _get_size_of(val: Any, *, recurse: bool = True) -> int: return 0 @@ -85,15 +85,6 @@ VT = TypeVar("VT") # a general type var, distinct from either KT or VT T = TypeVar("T") - -def enumerate_leaves(node, depth): - if depth == 0: - yield node - else: - for n in node.values(): - yield from enumerate_leaves(n, depth - 1) - - P = TypeVar("P") @@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]): __slots__ = ["last_access_ts_secs"] - def update_last_access(self, clock: Clock): + def update_last_access(self, clock: Clock) -> None: self.last_access_ts_secs = int(clock.time()) @@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node() @wrap_as_background_process("LruCache._expire_old_entries") -async def _expire_old_entries(clock: Clock, expiry_seconds: int): +async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: """Walks the global cache list to find cache entries that haven't been accessed in the given number of seconds. """ @@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int): logger.info("Dropped %d items from caches", i) -def setup_expire_lru_cache_entries(hs: "HomeServer"): +def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: """Start a background job that expires all cache entries if they have not been accessed for the given number of seconds. """ @@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"): ) -class _Node: +class _Node(Generic[KT, VT]): __slots__ = [ "_list_node", "_global_list_node", @@ -197,15 +188,16 @@ class _Node: def __init__( self, root: "ListNode[_Node]", - key, - value, + key: KT, + value: VT, cache: "weakref.ReferenceType[LruCache]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), + prune_unread_entries: bool = True, ): self._list_node = ListNode.insert_after(self, root) - self._global_list_node = None - if USE_GLOBAL_LIST: + self._global_list_node: Optional[_TimedListNode] = None + if USE_GLOBAL_LIST and prune_unread_entries: self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT) self._global_list_node.update_last_access(clock) @@ -314,6 +306,7 @@ class LruCache(Generic[KT, VT]): metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, + prune_unread_entries: bool = True, ): """ Args: @@ -403,11 +396,11 @@ class LruCache(Generic[KT, VT]): evicted_len = delete_node(node) cache.pop(node.key, None) if metrics: - metrics.inc_evictions(evicted_len) + metrics.inc_evictions(EvictionReason.size, evicted_len) def synchronized(f: FT) -> FT: @wraps(f) - def inner(*args, **kwargs): + def inner(*args: Any, **kwargs: Any) -> Any: with lock: return f(*args, **kwargs) @@ -416,18 +409,28 @@ class LruCache(Generic[KT, VT]): cached_cache_len = [0] if size_callback is not None: - def cache_len(): + def cache_len() -> int: return cached_cache_len[0] else: - def cache_len(): + def cache_len() -> int: return len(cache) self.len = synchronized(cache_len) - def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): - node = _Node(list_root, key, value, weak_ref_to_self, real_clock, callbacks) + def add_node( + key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () + ) -> None: + node = _Node( + list_root, + key, + value, + weak_ref_to_self, + real_clock, + callbacks, + prune_unread_entries, + ) cache[key] = node if size_callback: @@ -436,7 +439,7 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node: _Node): + def move_node_to_front(node: _Node) -> None: node.move_to_front(real_clock, list_root) def delete_node(node: _Node) -> int: @@ -478,7 +481,7 @@ class LruCache(Generic[KT, VT]): default: Optional[T] = None, callbacks: Collection[Callable[[], None]] = (), update_metrics: bool = True, - ): + ) -> Union[None, T, VT]: node = cache.get(key, None) if node is not None: move_node_to_front(node) @@ -492,7 +495,9 @@ class LruCache(Generic[KT, VT]): return default @synchronized - def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()): + def cache_set( + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () + ) -> None: node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -537,7 +542,7 @@ class LruCache(Generic[KT, VT]): ... @synchronized - def cache_pop(key: KT, default: Optional[T] = None): + def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: delete_node(node) @@ -602,25 +607,25 @@ class LruCache(Generic[KT, VT]): self.contains = cache_contains self.clear = cache_clear - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: result = self.get(key, self.sentinel) if result is self.sentinel: raise KeyError() else: - return result + return cast(VT, result) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: self.set(key, value) - def __delitem__(self, key, value): + def __delitem__(self, key: KT, value: VT) -> None: result = self.pop(key, self.sentinel) if result is self.sentinel: raise KeyError() - def __len__(self): + def __len__(self) -> int: return self.len() - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return self.contains(key) def set_cache_factor(self, factor: float) -> bool: |