diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8514a75a1c..ce736fdf75 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -36,6 +36,8 @@ from typing import (
)
from weakref import WeakValueDictionary
+import attr
+
from twisted.internet import defer
from twisted.python.failure import Failure
@@ -466,6 +468,35 @@ class _CacheContext:
)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class _CachedFunctionDescriptor:
+ """Helper for `@cached`, we name it so that we can hook into it with mypy
+ plugin."""
+
+ max_entries: int
+ num_args: Optional[int]
+ uncached_args: Optional[Collection[str]]
+ tree: bool
+ cache_context: bool
+ iterable: bool
+ prune_unread_entries: bool
+ name: Optional[str]
+
+ def __call__(self, orig: F) -> CachedFunction[F]:
+ d = DeferredCacheDescriptor(
+ orig,
+ max_entries=self.max_entries,
+ num_args=self.num_args,
+ uncached_args=self.uncached_args,
+ tree=self.tree,
+ cache_context=self.cache_context,
+ iterable=self.iterable,
+ prune_unread_entries=self.prune_unread_entries,
+ name=self.name,
+ )
+ return cast(CachedFunction[F], d)
+
+
def cached(
*,
max_entries: int = 1000,
@@ -476,9 +507,8 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
-) -> Callable[[F], CachedFunction[F]]:
- func = lambda orig: DeferredCacheDescriptor(
- orig,
+) -> _CachedFunctionDescriptor:
+ return _CachedFunctionDescriptor(
max_entries=max_entries,
num_args=num_args,
uncached_args=uncached_args,
@@ -489,7 +519,26 @@ def cached(
name=name,
)
- return cast(Callable[[F], CachedFunction[F]], func)
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class _CachedListFunctionDescriptor:
+ """Helper for `@cachedList`, we name it so that we can hook into it with mypy
+ plugin."""
+
+ cached_method_name: str
+ list_name: str
+ num_args: Optional[int] = None
+ name: Optional[str] = None
+
+ def __call__(self, orig: F) -> CachedFunction[F]:
+ d = DeferredCacheListDescriptor(
+ orig,
+ cached_method_name=self.cached_method_name,
+ list_name=self.list_name,
+ num_args=self.num_args,
+ name=self.name,
+ )
+ return cast(CachedFunction[F], d)
def cachedList(
@@ -498,7 +547,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
-) -> Callable[[F], CachedFunction[F]]:
+) -> _CachedListFunctionDescriptor:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
@@ -527,16 +576,13 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
- func = lambda orig: DeferredCacheListDescriptor(
- orig,
+ return _CachedListFunctionDescriptor(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
name=name,
)
- return cast(Callable[[F], CachedFunction[F]], func)
-
def _get_cache_key_builder(
param_names: Sequence[str],
|