summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py2
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/deferred_cache.py4
-rw-r--r--synapse/util/caches/dictionary_cache.py64
-rw-r--r--synapse/util/caches/expiringcache.py83
-rw-r--r--synapse/util/caches/ttlcache.py53
-rw-r--r--synapse/util/frozenutils.py2
7 files changed, 132 insertions, 82 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py

index f33c115844..c3b2d981ea 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -496,7 +496,7 @@ def timeout_deferred( try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except Exception: # if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") # the cancel() call should have set off a chain of errbacks which diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index e676c2cac4..48f64eeb38 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache logger = logging.getLogger(__name__) -caches_by_name = {} -collectors_by_name = {} # type: Dict +caches_by_name = {} # type: Dict[str, Sized] +collectors_by_name = {} # type: Dict[str, CacheMetric] cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) @@ -116,7 +116,7 @@ def register_cache( """ if resizable: if not resize_callback: - resize_callback = getattr(cache, "set_cache_factor") + resize_callback = cache.set_cache_factor # type: ignore add_resizable_cache(cache_name, resize_callback) metric = CacheMetric(cache, cache_type, cache_name, collect_callback) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1adc92eb90..dd392cf694 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py
@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]): # we return a new Deferred which will be called before any subsequent observers. return observable.observe() - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + def prefill( + self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None + ): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 588d2d49f2..b3b413b02c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -15,26 +15,38 @@ import enum import logging import threading -from collections import namedtuple -from typing import Any +from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar + +import attr from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): +# The type of the cache keys. +KT = TypeVar("KT") +# The type of the dictionary keys. +DKT = TypeVar("DKT") + + +@attr.s(slots=True) +class DictionaryEntry: """Returned when getting an entry from the cache Attributes: - full (bool): Whether the cache has the full or dict or just some keys. + full: Whether the cache has the full or dict or just some keys. If not full then not all requested keys will necessarily be present in `value` - known_absent (set): Keys that were looked up in the dict and were not + known_absent: Keys that were looked up in the dict and were not there. - value (dict): The full or partial dict value + value: The full or partial dict value """ + full = attr.ib(type=bool) + known_absent = attr.ib() + value = attr.ib() + def __len__(self): return len(self.value) @@ -45,21 +57,21 @@ class _Sentinel(enum.Enum): sentinel = object() -class DictionaryCache: +class DictionaryCache(Generic[KT, DKT]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ - def __init__(self, name, max_entries=1000): + def __init__(self, name: str, max_entries: int = 1000): self.cache = LruCache( max_size=max_entries, cache_name=name, size_callback=len - ) # type: LruCache[Any, DictionaryEntry] + ) # type: LruCache[KT, DictionaryEntry] self.name = name self.sequence = 0 - self.thread = None + self.thread = None # type: Optional[threading.Thread] - def check_thread(self): + def check_thread(self) -> None: expected_thread = self.thread if expected_thread is None: self.thread = threading.current_thread() @@ -69,12 +81,14 @@ class DictionaryCache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, dict_keys=None): + def get( + self, key: KT, dict_keys: Optional[Iterable[DKT]] = None + ) -> DictionaryEntry: """Fetch an entry out of the cache Args: key - dict_key(list): If given a set of keys then return only those keys + dict_key: If given a set of keys then return only those keys that exist in the cache. Returns: @@ -95,7 +109,7 @@ class DictionaryCache: return DictionaryEntry(False, set(), {}) - def invalidate(self, key): + def invalidate(self, key: KT) -> None: self.check_thread() # Increment the sequence number so that any SELECT statements that @@ -103,19 +117,25 @@ class DictionaryCache: self.sequence += 1 self.cache.pop(key, None) - def invalidate_all(self): + def invalidate_all(self) -> None: self.check_thread() self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, fetched_keys=None): + def update( + self, + sequence: int, + key: KT, + value: Dict[DKT, Any], + fetched_keys: Optional[Set[DKT]] = None, + ) -> None: """Updates the entry in the cache Args: sequence - key (K) - value (dict[X,Y]): The value to update the cache with. - fetched_keys (None|set[X]): All of the dictionary keys which were + key + value: The value to update the cache with. + fetched_keys: All of the dictionary keys which were fetched from the database. If None, this is the complete value for key K. Otherwise, it @@ -131,7 +151,9 @@ class DictionaryCache: else: self._update_or_insert(key, value, fetched_keys) - def _update_or_insert(self, key, value, known_absent): + def _update_or_insert( + self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] + ) -> None: # We pop and reinsert as we need to tell the cache the size may have # changed @@ -140,5 +162,5 @@ class DictionaryCache: entry.known_absent.update(known_absent) self.cache[key] = entry - def _insert(self, key, value, known_absent): + def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@ import logging from collections import OrderedDict +from typing import Any, Generic, Optional, TypeVar, Union, overload + +import attr +from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any + +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class ExpiringCache: + +class ExpiringCache(Generic[KT, VT]): def __init__( self, - cache_name, - clock, - max_len=0, - expiry_ms=0, - reset_expiry_on_get=False, - iterable=False, + cache_name: str, + clock: Clock, + max_len: int = 0, + expiry_ms: int = 0, + reset_expiry_on_get: bool = False, + iterable: bool = False, ): """ Args: - cache_name (str): Name of this cache, used for logging. - clock (Clock) - max_len (int): Max size of dict. If the dict grows larger than this + cache_name: Name of this cache, used for logging. + clock + max_len: Max size of dict. If the dict grows larger than this then the oldest items get automatically evicted. Default is 0, which indicates there is no max limit. - expiry_ms (int): How long before an item is evicted from the cache + expiry_ms: How long before an item is evicted from the cache in milliseconds. Default is 0, indicating items never get evicted based on time. - reset_expiry_on_get (bool): If true, will reset the expiry time for + reset_expiry_on_get: If true, will reset the expiry time for an item on access. Defaults to False. - iterable (bool): If true, the size is calculated by summing the + iterable: If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -62,7 +72,7 @@ class ExpiringCache: self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache = OrderedDict() + self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] self.iterable = iterable @@ -79,12 +89,12 @@ class ExpiringCache: self._clock.looping_call(f, self._expiry_ms / 2) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) self.evict() - def evict(self): + def evict(self) -> None: # Evict if there are now too many items while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) @@ -93,7 +103,7 @@ class ExpiringCache: else: self.metrics.inc_evictions() - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: try: entry = self._cache[key] self.metrics.inc_hits() @@ -106,7 +116,7 @@ class ExpiringCache: return entry.value - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Removes and returns the value with the given key from the cache. If the key isn't in the cache then `default` will be returned if @@ -115,29 +125,40 @@ class ExpiringCache: Identical functionality to `dict.pop(..)`. """ - value = self._cache.pop(key, default) + value = self._cache.pop(key, SENTINEL) + # The key was not found. if value is SENTINEL: - raise KeyError(key) + if default is SENTINEL: + raise KeyError(key) + return default - return value + return value.value - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._cache - def get(self, key, default=None): + @overload + def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def get(self, key: KT, default: T) -> Union[VT, T]: + ... + + def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: try: return self[key] except KeyError: return default - def setdefault(self, key, value): + def setdefault(self, key: KT, value: VT) -> VT: try: return self[key] except KeyError: self[key] = value return value - def _prune_cache(self): + def _prune_cache(self) -> None: if not self._expiry_ms: # zero expiry time means don't expire. This should never get called # since we have this check in start too. @@ -166,7 +187,7 @@ class ExpiringCache: len(self), ) - def __len__(self): + def __len__(self) -> int: if self.iterable: return sum(len(entry.value) for entry in self._cache.values()) else: @@ -190,9 +211,7 @@ class ExpiringCache: return False +@attr.s(slots=True) class _CacheEntry: - __slots__ = ["time", "value"] - - def __init__(self, time, value): - self.time = time - self.value = value + time = attr.ib(type=int) + value = attr.ib() diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6ce2a3d12b..96a8274940 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -15,6 +15,7 @@ import logging import time +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union import attr from sortedcontainers import SortedList @@ -23,15 +24,19 @@ from synapse.util.caches import register_cache logger = logging.getLogger(__name__) -SENTINEL = object() +SENTINEL = object() # type: Any +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") -class TTLCache: + +class TTLCache(Generic[KT, VT]): """A key/value cache implementation where each entry has its own TTL""" - def __init__(self, cache_name, timer=time.time): + def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data = {} + self._data = {} # type: Dict[KT, _CacheEntry] # the _CacheEntries, sorted by expiry time self._expiry_list = SortedList() # type: SortedList[_CacheEntry] @@ -40,26 +45,27 @@ class TTLCache: self._metrics = register_cache("ttl", cache_name, self, resizable=False) - def set(self, key, value, ttl): + def set(self, key: KT, value: VT, ttl: float) -> None: """Add/update an entry in the cache Args: key: key for this entry value: value for this entry - ttl (float): TTL for this entry, in seconds + ttl: TTL for this entry, in seconds """ expiry = self._timer() + ttl self.expire() e = self._data.pop(key, SENTINEL) - if e != SENTINEL: + if e is not SENTINEL: + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value) self._data[key] = entry self._expiry_list.add(entry) - def get(self, key, default=SENTINEL): + def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: """Get a value from the cache Args: @@ -72,23 +78,23 @@ class TTLCache: """ self.expire() e = self._data.get(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._metrics.inc_hits() return e.value - def get_with_expiry(self, key): + def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]: """Get a value, and its expiry time, from the cache Args: key: key to look up Returns: - Tuple[Any, float, float]: the value from the cache, the expiry time - and the TTL + A tuple of the value from the cache, the expiry time and the TTL Raises: KeyError if the entry is not found @@ -102,7 +108,7 @@ class TTLCache: self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key, default=SENTINEL): + def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. @@ -118,29 +124,30 @@ class TTLCache: """ self.expire() e = self._data.pop(key, SENTINEL) - if e == SENTINEL: + if e is SENTINEL: self._metrics.inc_misses() - if default == SENTINEL: + if default is SENTINEL: raise KeyError(key) return default + assert isinstance(e, _CacheEntry) self._expiry_list.remove(e) self._metrics.inc_hits() return e.value - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: return self.get(key) - def __delitem__(self, key): + def __delitem__(self, key: KT) -> None: self.pop(key) - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return key in self._data - def __len__(self): + def __len__(self) -> int: self.expire() return len(self._data) - def expire(self): + def expire(self) -> None: """Run the expiry on the cache. Any entries whose expiry times are due will be removed """ @@ -158,7 +165,7 @@ class _CacheEntry: """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. - expiry_time = attr.ib() - ttl = attr.ib() + expiry_time = attr.ib(type=float) + ttl = attr.ib(type=float) key = attr.ib() value = attr.ib() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 5f7a6dd1d3..5ca2e71e60 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -36,7 +36,7 @@ def freeze(o): def unfreeze(o): if isinstance(o, (dict, frozendict)): - return dict({k: unfreeze(v) for k, v in o.items()}) + return {k: unfreeze(v) for k, v in o.items()} if isinstance(o, (bytes, str)): return o