From 44bb881096d7ad4d730e2dc31e0094c0324e0970 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 6 Apr 2021 08:58:18 -0400 Subject: Add type hints to expiring cache. (#9730) --- synapse/util/caches/expiringcache.py | 83 ++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 32 deletions(-) (limited to 'synapse/util/caches/expiringcache.py') diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index e15f7ee698..4dc3477e89 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,40 +15,50 @@ import logging from collections import OrderedDict +from typing import Any, Generic, Optional, TypeVar, Union, overload + +import attr +from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any + +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class ExpiringCache: + +class ExpiringCache(Generic[KT, VT]): def __init__( self, - cache_name, - clock, - max_len=0, - expiry_ms=0, - reset_expiry_on_get=False, - iterable=False, + cache_name: str, + clock: Clock, + max_len: int = 0, + expiry_ms: int = 0, + reset_expiry_on_get: bool = False, + iterable: bool = False, ): """ Args: - cache_name (str): Name of this cache, used for logging. - clock (Clock) - max_len (int): Max size of dict. If the dict grows larger than this + cache_name: Name of this cache, used for logging. + clock + max_len: Max size of dict. If the dict grows larger than this then the oldest items get automatically evicted. Default is 0, which indicates there is no max limit. - expiry_ms (int): How long before an item is evicted from the cache + expiry_ms: How long before an item is evicted from the cache in milliseconds. Default is 0, indicating items never get evicted based on time. - reset_expiry_on_get (bool): If true, will reset the expiry time for + reset_expiry_on_get: If true, will reset the expiry time for an item on access. Defaults to False. - iterable (bool): If true, the size is calculated by summing the + iterable: If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -62,7 +72,7 @@ class ExpiringCache: self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache = OrderedDict() + self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] self.iterable = iterable @@ -79,12 +89,12 @@ class ExpiringCache: self._clock.looping_call(f, self._expiry_ms / 2) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) self.evict() - def evict(self): + def evict(self) -> None: # Evict if there are now too many items while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) @@ -93,7 +103,7 @@ class ExpiringCache: else: self.metrics.inc_evictions() - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: try: entry = self._cache[key] self.metrics.inc_hits() @@ -106,7 +116,7 @@ class ExpiringCache: return entry.value - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Removes and returns the value with the given key from the cache. If the key isn't in the cache then `default` will be returned if @@ -115,29 +125,40 @@ class ExpiringCache: Identical functionality to `dict.pop(..)`. """ - value = self._cache.pop(key, default) + value = self._cache.pop(key, SENTINEL) + # The key was not found. if value is SENTINEL: - raise KeyError(key) + if default is SENTINEL: + raise KeyError(key) + return default - return value + return value.value - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._cache - def get(self, key, default=None): + @overload + def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def get(self, key: KT, default: T) -> Union[VT, T]: + ... + + def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: try: return self[key] except KeyError: return default - def setdefault(self, key, value): + def setdefault(self, key: KT, value: VT) -> VT: try: return self[key] except KeyError: self[key] = value return value - def _prune_cache(self): + def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. @@ -166,7 +187,7 @@ class ExpiringCache: len(self), ) - def __len__(self): + def __len__(self) -> int: if self.iterable: return sum(len(entry.value) for entry in self._cache.values()) else: @@ -190,9 +211,7 @@ class ExpiringCache: return False +@attr.s(slots=True) class _CacheEntry: - __slots__ = ["time", "value"] - - def __init__(self, time, value): - self.time = time - self.value = value + time = attr.ib(type=int) + value = attr.ib() -- cgit 1.4.1