diff options
Diffstat (limited to 'synapse/util/caches')
-rw-r--r-- | synapse/util/caches/descriptors.py | 42 |
1 files changed, 28 insertions, 14 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 49d9fddcf0..825810eb16 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,11 +18,10 @@ import functools import inspect import logging import threading -from typing import Any, Tuple, Union, cast +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary from prometheus_client import Gauge -from typing_extensions import Protocol from twisted.internet import defer @@ -38,8 +37,10 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] +F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Protocol): + +class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any invalidate_many = None # type: Any @@ -47,8 +48,11 @@ class _CachedFunction(Protocol): cache = None # type: Any num_args = None # type: Any - def __name__(self): - ... + __name__ = None # type: str + + # Note: This function signature is actually fiddled with by the synapse mypy + # plugin to a) make it a bound method, and b) remove any `cache_context` arg. + __call__ = None # type: F cache_pending_metric = Gauge( @@ -123,7 +127,7 @@ class Cache(object): self.name = name self.keylen = keylen - self.thread = None + self.thread = None # type: Optional[threading.Thread] self.metrics = register_cache( "cache", name, @@ -662,9 +666,13 @@ class _CacheContext: def cached( - max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False -): - return lambda orig: CacheDescriptor( + max_entries: int = 1000, + num_args: Optional[int] = None, + tree: bool = False, + cache_context: bool = False, + iterable: bool = False, +) -> Callable[[F], _CachedFunction[F]]: + func = lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -673,8 +681,12 @@ def cached( iterable=iterable, ) + return cast(Callable[[F], _CachedFunction[F]], func) -def cachedList(cached_method_name, list_name, num_args=None): + +def cachedList( + cached_method_name: str, list_name: str, num_args: Optional[int] = None +) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None): cache. Args: - cached_method_name (str): The name of the single-item lookup method. + cached_method_name: The name of the single-item lookup method. This is only used to find the cache to use. - list_name (str): The name of the argument that is the list to use to + list_name: The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache + num_args: Number of arguments to use as the key in the cache (including list_name). Defaults to all named parameters. Example: @@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None): def batch_do_something(self, first_arg, second_args): ... """ - return lambda orig: CacheListDescriptor( + func = lambda orig: CacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, ) + + return cast(Callable[[F], _CachedFunction[F]], func) |