diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index b9dcca17f1..375cd443f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,12 +19,15 @@ import logging
from typing import (
Any,
Callable,
+ Dict,
Generic,
+ Hashable,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
+ Type,
TypeVar,
Union,
cast,
@@ -32,6 +35,7 @@ from typing import (
from weakref import WeakValueDictionary
from twisted.internet import defer
+from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ num_args: Optional[int],
+ cache_context: bool = False,
+ ):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __init__(
self,
- orig,
+ orig: Callable[..., Any],
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
self.max_entries = max_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__,
max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
sentinel = LruCacheDescriptor._Sentinel.sentinel
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
invalidate_callback = kwargs.pop("on_invalidate", None)
callbacks = (invalidate_callback,) if invalidate_callback else ()
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return r1 + r2
Args:
- num_args (int): number of positional arguments (excluding ``self`` and
+ num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
def __init__(
self,
- orig,
- max_entries=1000,
- num_args=None,
- tree=False,
- cache_context=False,
- iterable=False,
+ orig: Callable[..., Any],
+ max_entries: int = 1000,
+ num_args: Optional[int] = None,
+ tree: bool = False,
+ cache_context: bool = False,
+ iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries
- def __get__(self, obj, owner):
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
- def _wrapped(*args, **kwargs):
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
of results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None):
+ def __init__(
+ self,
+ orig: Callable[..., Any],
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ ):
"""
Args:
- orig (function)
- cached_method_name (str): The name of the cached method.
- list_name (str): Name of the argument which is the bulk lookup list
- num_args (int): number of positional arguments (excluding ``self``,
+ orig
+ cached_method_name: The name of the cached method.
+ list_name: Name of the argument which is the bulk lookup list
+ num_args: number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
% (self.list_name, cached_method_name)
)
- def __get__(self, obj, objtype=None):
+ def __get__(
+ self, obj: Optional[Any], objtype: Optional[Type] = None
+ ) -> Callable[..., Any]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
results = {}
- def update_results_dict(res, arg):
+ def update_results_dict(res: Any, arg: Hashable) -> None:
results[arg] = res
# list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# otherwise a tuple is used.
if num_args == 1:
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
else:
keylist = list(keyargs)
- def arg_to_cache_key(arg):
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
keylist[self.list_pos] = arg
return tuple(keylist)
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
- def complete_all(res):
+ def complete_all(res: Dict[Hashable, Any]) -> None:
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferreds_map[e].callback(val)
results[e] = val
- def errback(f):
+ def errback(f: Failure) -> Failure:
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
|