diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 64 |
1 files changed, 55 insertions, 9 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8514a75a1c..ce736fdf75 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -36,6 +36,8 @@ from typing import ( ) from weakref import WeakValueDictionary +import attr + from twisted.internet import defer from twisted.python.failure import Failure @@ -466,6 +468,35 @@ class _CacheContext: ) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedFunctionDescriptor: + """Helper for `@cached`, we name it so that we can hook into it with mypy + plugin.""" + + max_entries: int + num_args: Optional[int] + uncached_args: Optional[Collection[str]] + tree: bool + cache_context: bool + iterable: bool + prune_unread_entries: bool + name: Optional[str] + + def __call__(self, orig: F) -> CachedFunction[F]: + d = DeferredCacheDescriptor( + orig, + max_entries=self.max_entries, + num_args=self.num_args, + uncached_args=self.uncached_args, + tree=self.tree, + cache_context=self.cache_context, + iterable=self.iterable, + prune_unread_entries=self.prune_unread_entries, + name=self.name, + ) + return cast(CachedFunction[F], d) + + def cached( *, max_entries: int = 1000, @@ -476,9 +507,8 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: - func = lambda orig: DeferredCacheDescriptor( - orig, +) -> _CachedFunctionDescriptor: + return _CachedFunctionDescriptor( max_entries=max_entries, num_args=num_args, uncached_args=uncached_args, @@ -489,7 +519,26 @@ def cached( name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedListFunctionDescriptor: + """Helper for `@cachedList`, we name it so that we can hook into it with mypy + plugin.""" + + cached_method_name: str + list_name: str + num_args: Optional[int] = None + name: Optional[str] = None + + def __call__(self, orig: F) -> CachedFunction[F]: + d = DeferredCacheListDescriptor( + orig, + cached_method_name=self.cached_method_name, + list_name=self.list_name, + num_args=self.num_args, + name=self.name, + ) + return cast(CachedFunction[F], d) def cachedList( @@ -498,7 +547,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: +) -> _CachedListFunctionDescriptor: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -527,16 +576,13 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: DeferredCacheListDescriptor( - orig, + return _CachedListFunctionDescriptor( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) - def _get_cache_key_builder( param_names: Sequence[str], |