diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index df4d61e4b6..15debd6c46 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -17,7 +17,7 @@ import logging
import typing
from enum import Enum, auto
from sys import intern
-from typing import Callable, Dict, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized
import attr
from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
time = auto()
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class CacheMetric:
- _cache = attr.ib()
- _cache_type = attr.ib(type=str)
- _cache_name = attr.ib(type=str)
- _collect_callback = attr.ib(type=Optional[Callable])
+ _cache: Sized
+ _cache_type: str
+ _cache_name: str
+ _collect_callback: Optional[Callable]
- hits = attr.ib(default=0)
- misses = attr.ib(default=0)
+ hits: int = 0
+ misses: int = 0
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
factory=collections.Counter
)
- memory_usage = attr.ib(default=None)
+ memory_usage: Optional[int] = None
def inc_hits(self) -> None:
self.hits += 1
@@ -89,13 +89,14 @@ class CacheMetric:
self.memory_usage += memory
def dec_memory_usage(self, memory: int) -> None:
+ assert self.memory_usage is not None
self.memory_usage -= memory
def clear_memory_usage(self) -> None:
if self.memory_usage is not None:
self.memory_usage = 0
- def describe(self):
+ def describe(self) -> List[str]:
return []
def collect(self) -> None:
@@ -118,8 +119,9 @@ class CacheMetric:
self.eviction_size_by_reason[reason]
)
cache_total.labels(self._cache_name).set(self.hits + self.misses)
- if getattr(self._cache, "max_size", None):
- cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+ max_size = getattr(self._cache, "max_size", None)
+ if max_size:
+ cache_max_size.labels(self._cache_name).set(max_size)
if TRACK_MEMORY_USAGE:
# self.memory_usage can be None if nothing has been inserted
@@ -193,7 +195,7 @@ KNOWN_KEYS = {
}
-def intern_string(string):
+def intern_string(string: Optional[str]) -> Optional[str]:
"""Takes a (potentially) unicode string and interns it if it's ascii"""
if string is None:
return None
@@ -204,7 +206,7 @@ def intern_string(string):
return string
-def intern_dict(dictionary):
+def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Takes a dictionary and interns well known keys and their values"""
return {
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -212,7 +214,7 @@ def intern_dict(dictionary):
}
-def _intern_known_values(key, value):
+def _intern_known_values(key: str, value: Any) -> Any:
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index da502aec11..3c4cc093af 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
- def invalidate(self, key) -> None:
+ def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index b9dcca17f1..375cd443f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,12 +19,15 @@ import logging
from typing import (
Any,
Callable,
+ Dict,
Generic,
+ Hashable,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
+ Type,
TypeVar,
Union,
cast,
@@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary
from twisted.internet import defer
+from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ num_args: Optional[int],
+ cache_context: bool = False,
+ ):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__(
self,
- orig,
+ orig: Callable[..., Any],
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
- num_args (int): number of positional arguments (excluding ``self`` and
+ num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(
self,
- orig,
- max_entries=1000,
- num_args=None,
- tree=False,
- cache_context=False,
- iterable=False,
+ orig: Callable[..., Any],
+ max_entries: int = 1000,
+ num_args: Optional[int] = None,
+ tree: bool = False,
+ cache_context: bool = False,
+ iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ ):
"""
Args:
- orig (function)
- cached_method_name (str): The name of the cached method.
- list_name (str): Name of the argument which is the bulk lookup list
- num_args (int): number of positional arguments (excluding ``self``,
+ orig
+ cached_method_name: The name of the cached method.
+ list_name: Name of the argument which is the bulk lookup list
+ num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name)
)
- def __get__(self, obj, objtype=None):
+ def __get__(
+ self, obj: Optional[Any], objtype: Optional[Type] = None
+ ) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {}
- def update_results_dict(res, arg):
+ def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used.
if num_args == 1:
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
else:
keylist = list(keyargs)
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg
return tuple(keylist)
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
- def complete_all(res):
+ def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val)
results[e] = val
- def errback(f):
+ def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c3f72aa06d..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
import attr
from typing_extensions import Literal
+from twisted.internet import defer
+
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
@@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
# Don't bother starting the loop if things never expire
return
- def f():
+ def f() -> "defer.Deferred[None]":
return run_as_background_process(
"prune_cache_%s" % self._cache_name, self._prune_cache
)
@@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
self[key] = value
return value
- def _prune_cache(self) -> None:
+ async def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
return False
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class _CacheEntry:
- time = attr.ib(type=int)
- value = attr.ib()
+ time: int
+ value: Any
|