diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/caches/dictionary_cache.py | 64 | ||||
-rw-r--r-- | synapse/util/caches/ttlcache.py | 53 |
2 files changed, 73 insertions, 44 deletions
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 588d2d49f2..b3b413b02c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -15,26 +15,38 @@ import enum import logging import threading -from collections import namedtuple -from typing import Any +from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar + +import attr from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): +# The type of the cache keys. +KT = TypeVar("KT") +# The type of the dictionary keys. +DKT = TypeVar("DKT") + + +@attr.s(slots=True) +class DictionaryEntry: """Returned when getting an entry from the cache Attributes: - full (bool): Whether the cache has the full or dict or just some keys. + full: Whether the cache has the full or dict or just some keys. If not full then not all requested keys will necessarily be present in `value` - known_absent (set): Keys that were looked up in the dict and were not + known_absent: Keys that were looked up in the dict and were not there. - value (dict): The full or partial dict value + value: The full or partial dict value """ + full = attr.ib(type=bool) + known_absent = attr.ib() + value = attr.ib() + def __len__(self): return len(self.value) @@ -45,21 +57,21 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache: +class DictionaryCache(Generic[KT, DKT]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ - def __init__(self, name, max_entries=1000): + def __init__(self, name: str, max_entries: int = 1000): self.cache = LruCache( max_size=max_entries, cache_name=name, size_callback=len - ) # type: LruCache[Any, DictionaryEntry] + ) # type: LruCache[KT, DictionaryEntry] self.name = name self.sequence = 0 - self.thread = None + self.thread = None # type: Optional[threading.Thread] - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -69,12 +81,14 @@ class DictionaryCache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, dict_keys=None): + def get( + self, key: KT, dict_keys: Optional[Iterable[DKT]] = None + ) -> DictionaryEntry: """Fetch an entry out of the cache Args: key - dict_key(list): If given a set of keys then return only those keys + dict_key: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -95,7 +109,7 @@ class DictionaryCache: return DictionaryEntry(False, set(), {}) - def invalidate(self, key): + def invalidate(self, key: KT) -> None: self.check_thread() # Increment the sequence number so that any SELECT statements that @@ -103,19 +117,25 @@ class DictionaryCache: self.sequence += 1 self.cache.pop(key, None) - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, fetched_keys=None): + def update( + self, + sequence: int, + key: KT, + value: Dict[DKT, Any], + fetched_keys: Optional[Set[DKT]] = None, + ) -> None: """Updates the entry in the cache Args: sequence - key (K) - value (dict[X,Y]): The value to update the cache with. - fetched_keys (None|set[X]): All of the dictionary keys which were + key + value: The value to update the cache with. + fetched_keys: All of the dictionary keys which were fetched from the database. If None, this is the complete value for key K. Otherwise, it @@ -131,7 +151,9 @@ class DictionaryCache: else: self._update_or_insert(key, value, fetched_keys) - def _update_or_insert(self, key, value, known_absent): + def _update_or_insert( + self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed @@ -140,5 +162,5 @@ class DictionaryCache: entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key, value, known_absent): + def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 6ce2a3d12b..96a8274940 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -15,6 +15,7 @@ import logging import time +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union import attr from sortedcontainers import SortedList @@ -23,15 +24,19 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class TTLCache: + +class TTLCache(Generic[KT, VT]): """A key/value cache implementation where each entry has its own TTL""" - def __init__(self, cache_name, timer=time.time): + def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data = {} + self._data = {} # type: Dict[KT, _CacheEntry] # the _CacheEntries, sorted by expiry time self._expiry_list = SortedList() # type: SortedList[_CacheEntry] @@ -40,26 +45,27 @@ class TTLCache: self._metrics = register_cache("ttl", cache_name, self, resizable=False) - def set(self, key, value, ttl): + def set(self, key: KT, value: VT, ttl: float) -> None: """Add/update an entry in the cache Args: key: key for this entry value: value for this entry - ttl (float): TTL for this entry, in seconds + ttl: TTL for this entry, in seconds """ expiry = self._timer() + ttl self.expire() e = self._data.pop(key, SENTINEL) - if e != SENTINEL: + if e is not SENTINEL: + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) self._data[key] = entry self._expiry_list.add(entry) - def get(self, key, default=SENTINEL): + def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Get a value from the cache Args: @@ -72,23 +78,23 @@ class TTLCache: """ self.expire() e = self._data.get(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._metrics.inc_hits() return e.value - def get_with_expiry(self, key): + def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]: """Get a value, and its expiry time, from the cache Args: key: key to look up Returns: - Tuple[Any, float, float]: the value from the cache, the expiry time - and the TTL + A tuple of the value from the cache, the expiry time and the TTL Raises: KeyError if the entry is not found @@ -102,7 +108,7 @@ class TTLCache: self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. @@ -118,29 +124,30 @@ class TTLCache: """ self.expire() e = self._data.pop(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) self._metrics.inc_hits() return e.value - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: return self.get(key) - def __delitem__(self, key): + def __delitem__(self, key: KT) -> None: self.pop(key) - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._data - def __len__(self): + def __len__(self) -> int: self.expire() return len(self._data) - def expire(self): + def expire(self) -> None: """Run the expiry on the cache. Any entries whose expiry times are due will be removed """ @@ -158,7 +165,7 @@ class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. - expiry_time = attr.ib() - ttl = attr.ib() + expiry_time = attr.ib(type=float) + ttl = attr.ib(type=float) key = attr.ib() value = attr.ib() |