diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..9d4bc89edb 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -73,8 +73,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -211,7 +213,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
+ cache_name=self.name,
max_size=self.max_entries,
)
@@ -241,7 +243,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -301,12 +303,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +325,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -372,7 +376,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +408,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -525,7 +530,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,6 +582,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
@@ -587,13 +593,18 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
@@ -628,6 +639,7 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
|