diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f33c115844..c3b2d981ea 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -496,7 +496,7 @@ def timeout_deferred(
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except Exception: # if we throw any exception it'll break timeouts
logger.exception("Canceller failed during timeout")
# the cancel() call should have set off a chain of errbacks which
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index e676c2cac4..48f64eeb38 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__)
-caches_by_name = {}
-collectors_by_name = {} # type: Dict
+caches_by_name = {} # type: Dict[str, Sized]
+collectors_by_name = {} # type: Dict[str, CacheMetric]
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
@@ -116,7 +116,7 @@ def register_cache(
"""
if resizable:
if not resize_callback:
- resize_callback = getattr(cache, "set_cache_factor")
+ resize_callback = cache.set_cache_factor # type: ignore
add_resizable_cache(cache_name, resize_callback)
metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1adc92eb90..dd392cf694 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -283,7 +283,9 @@ class DeferredCache(Generic[KT, VT]):
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
- def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+ def prefill(
+ self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
+ ):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 588d2d49f2..b3b413b02c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,26 +15,38 @@
import enum
import logging
import threading
-from collections import namedtuple
-from typing import Any
+from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+
+import attr
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+# The type of the cache keys.
+KT = TypeVar("KT")
+# The type of the dictionary keys.
+DKT = TypeVar("DKT")
+
+
+@attr.s(slots=True)
+class DictionaryEntry:
"""Returned when getting an entry from the cache
Attributes:
- full (bool): Whether the cache has the full or dict or just some keys.
+ full: Whether the cache has the full dict or just some keys.
If not full then not all requested keys will necessarily be present
in `value`
- known_absent (set): Keys that were looked up in the dict and were not
+ known_absent: Keys that were looked up in the dict and were not
there.
- value (dict): The full or partial dict value
+ value: The full or partial dict value
"""
+ full = attr.ib(type=bool)
+ known_absent = attr.ib()
+ value = attr.ib()
+
def __len__(self):
return len(self.value)
@@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
sentinel = object()
-class DictionaryCache:
+class DictionaryCache(Generic[KT, DKT]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
- def __init__(self, name, max_entries=1000):
+ def __init__(self, name: str, max_entries: int = 1000):
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
- ) # type: LruCache[Any, DictionaryEntry]
+ ) # type: LruCache[KT, DictionaryEntry]
self.name = name
self.sequence = 0
- self.thread = None
+ self.thread = None # type: Optional[threading.Thread]
- def check_thread(self):
+ def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@@ -69,12 +81,14 @@ class DictionaryCache:
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, dict_keys=None):
+ def get(
+ self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
+ ) -> DictionaryEntry:
"""Fetch an entry out of the cache
Args:
key
- dict_key(list): If given a set of keys then return only those keys
+ dict_keys: If given an iterable of keys then return only those keys
that exist in the cache.
Returns:
@@ -95,7 +109,7 @@ class DictionaryCache:
return DictionaryEntry(False, set(), {})
- def invalidate(self, key):
+ def invalidate(self, key: KT) -> None:
self.check_thread()
# Increment the sequence number so that any SELECT statements that
@@ -103,19 +117,25 @@ class DictionaryCache:
self.sequence += 1
self.cache.pop(key, None)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
self.check_thread()
self.sequence += 1
self.cache.clear()
- def update(self, sequence, key, value, fetched_keys=None):
+ def update(
+ self,
+ sequence: int,
+ key: KT,
+ value: Dict[DKT, Any],
+ fetched_keys: Optional[Set[DKT]] = None,
+ ) -> None:
"""Updates the entry in the cache
Args:
sequence
- key (K)
- value (dict[X,Y]): The value to update the cache with.
- fetched_keys (None|set[X]): All of the dictionary keys which were
+ key
+ value: The value to update the cache with.
+ fetched_keys: All of the dictionary keys which were
fetched from the database.
If None, this is the complete value for key K. Otherwise, it
@@ -131,7 +151,9 @@ class DictionaryCache:
else:
self._update_or_insert(key, value, fetched_keys)
- def _update_or_insert(self, key, value, known_absent):
+ def _update_or_insert(
+ self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+ ) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
@@ -140,5 +162,5 @@ class DictionaryCache:
entry.known_absent.update(known_absent)
self.cache[key] = entry
- def _insert(self, key, value, known_absent):
+ def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
import logging
from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
def __init__(
self,
- cache_name,
- clock,
- max_len=0,
- expiry_ms=0,
- reset_expiry_on_get=False,
- iterable=False,
+ cache_name: str,
+ clock: Clock,
+ max_len: int = 0,
+ expiry_ms: int = 0,
+ reset_expiry_on_get: bool = False,
+ iterable: bool = False,
):
"""
Args:
- cache_name (str): Name of this cache, used for logging.
- clock (Clock)
- max_len (int): Max size of dict. If the dict grows larger than this
+ cache_name: Name of this cache, used for logging.
+ clock
+ max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
- expiry_ms (int): How long before an item is evicted from the cache
+ expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
- reset_expiry_on_get (bool): If true, will reset the expiry time for
+ reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False.
- iterable (bool): If true, the size is calculated by summing the
+ iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
- self._cache = OrderedDict()
+ self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry]
self.iterable = iterable
@@ -79,12 +89,12 @@ class ExpiringCache:
self._clock.looping_call(f, self._expiry_ms / 2)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
self.evict()
- def evict(self):
+ def evict(self) -> None:
# Evict if there are now too many items
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
else:
self.metrics.inc_evictions()
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
try:
entry = self._cache[key]
self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
return entry.value
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache.
If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
Identical functionality to `dict.pop(..)`.
"""
- value = self._cache.pop(key, default)
+ value = self._cache.pop(key, SENTINEL)
+ # The key was not found.
if value is SENTINEL:
- raise KeyError(key)
+ if default is SENTINEL:
+ raise KeyError(key)
+ return default
- return value
+ return value.value
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._cache
- def get(self, key, default=None):
+ @overload
+ def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def get(self, key: KT, default: T) -> Union[VT, T]:
+ ...
+
+ def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try:
return self[key]
except KeyError:
return default
- def setdefault(self, key, value):
+ def setdefault(self, key: KT, value: VT) -> VT:
try:
return self[key]
except KeyError:
self[key] = value
return value
- def _prune_cache(self):
+ def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
len(self),
)
- def __len__(self):
+ def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
else:
@@ -190,9 +211,7 @@ class ExpiringCache:
return False
+@attr.s(slots=True)
class _CacheEntry:
- __slots__ = ["time", "value"]
-
- def __init__(self, time, value):
- self.time = time
- self.value = value
+ time = attr.ib(type=int)
+ value = attr.ib()
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 6ce2a3d12b..96a8274940 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -15,6 +15,7 @@
import logging
import time
+from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
import attr
from sortedcontainers import SortedList
@@ -23,15 +24,19 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__)
-SENTINEL = object()
+SENTINEL = object() # type: Any
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
-class TTLCache:
+
+class TTLCache(Generic[KT, VT]):
"""A key/value cache implementation where each entry has its own TTL"""
- def __init__(self, cache_name, timer=time.time):
+ def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry
- self._data = {}
+ self._data = {} # type: Dict[KT, _CacheEntry]
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@@ -40,26 +45,27 @@ class TTLCache:
self._metrics = register_cache("ttl", cache_name, self, resizable=False)
- def set(self, key, value, ttl):
+ def set(self, key: KT, value: VT, ttl: float) -> None:
"""Add/update an entry in the cache
Args:
key: key for this entry
value: value for this entry
- ttl (float): TTL for this entry, in seconds
+ ttl: TTL for this entry, in seconds
"""
expiry = self._timer() + ttl
self.expire()
e = self._data.pop(key, SENTINEL)
- if e != SENTINEL:
+ if e is not SENTINEL:
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
- def get(self, key, default=SENTINEL):
+ def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Get a value from the cache
Args:
@@ -72,23 +78,23 @@ class TTLCache:
"""
self.expire()
e = self._data.get(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._metrics.inc_hits()
return e.value
- def get_with_expiry(self, key):
+ def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
"""Get a value, and its expiry time, from the cache
Args:
key: key to look up
Returns:
- Tuple[Any, float, float]: the value from the cache, the expiry time
- and the TTL
+ A tuple of the value from the cache, the expiry time and the TTL
Raises:
KeyError if the entry is not found
@@ -102,7 +108,7 @@ class TTLCache:
self._metrics.inc_hits()
return e.value, e.expiry_time, e.ttl
- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
"""Remove a value from the cache
If key is in the cache, remove it and return its value, else return default.
@@ -118,29 +124,30 @@ class TTLCache:
"""
self.expire()
e = self._data.pop(key, SENTINEL)
- if e == SENTINEL:
+ if e is SENTINEL:
self._metrics.inc_misses()
- if default == SENTINEL:
+ if default is SENTINEL:
raise KeyError(key)
return default
+ assert isinstance(e, _CacheEntry)
self._expiry_list.remove(e)
self._metrics.inc_hits()
return e.value
- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
return self.get(key)
- def __delitem__(self, key):
+ def __delitem__(self, key: KT) -> None:
self.pop(key)
- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._data
- def __len__(self):
+ def __len__(self) -> int:
self.expire()
return len(self._data)
- def expire(self):
+ def expire(self) -> None:
"""Run the expiry on the cache. Any entries whose expiry times are due will
be removed
"""
@@ -158,7 +165,7 @@ class _CacheEntry:
"""TTLCache entry"""
# expiry_time is the first attribute, so that entries are sorted by expiry.
- expiry_time = attr.ib()
- ttl = attr.ib()
+ expiry_time = attr.ib(type=float)
+ ttl = attr.ib(type=float)
key = attr.ib()
value = attr.ib()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 5f7a6dd1d3..5ca2e71e60 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -36,7 +36,7 @@ def freeze(o):
def unfreeze(o):
if isinstance(o, (dict, frozendict)):
- return dict({k: unfreeze(v) for k, v in o.items()})
+ return {k: unfreeze(v) for k, v in o.items()}
if isinstance(o, (bytes, str)):
return o
|