summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py64
1 files changed, 55 insertions, 9 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8514a75a1c..ce736fdf75 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -36,6 +36,8 @@ from typing import (
 )
 from weakref import WeakValueDictionary
 
+import attr
+
 from twisted.internet import defer
 from twisted.python.failure import Failure
 
@@ -466,6 +468,35 @@ class _CacheContext:
         )
 
 
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class _CachedFunctionDescriptor:
+    """Helper for `@cached`, we name it so that we can hook into it with mypy
+    plugin."""
+
+    max_entries: int
+    num_args: Optional[int]
+    uncached_args: Optional[Collection[str]]
+    tree: bool
+    cache_context: bool
+    iterable: bool
+    prune_unread_entries: bool
+    name: Optional[str]
+
+    def __call__(self, orig: F) -> CachedFunction[F]:
+        d = DeferredCacheDescriptor(
+            orig,
+            max_entries=self.max_entries,
+            num_args=self.num_args,
+            uncached_args=self.uncached_args,
+            tree=self.tree,
+            cache_context=self.cache_context,
+            iterable=self.iterable,
+            prune_unread_entries=self.prune_unread_entries,
+            name=self.name,
+        )
+        return cast(CachedFunction[F], d)
+
+
 def cached(
     *,
     max_entries: int = 1000,
@@ -476,9 +507,8 @@ def cached(
     iterable: bool = False,
     prune_unread_entries: bool = True,
     name: Optional[str] = None,
-) -> Callable[[F], CachedFunction[F]]:
-    func = lambda orig: DeferredCacheDescriptor(
-        orig,
+) -> _CachedFunctionDescriptor:
+    return _CachedFunctionDescriptor(
         max_entries=max_entries,
         num_args=num_args,
         uncached_args=uncached_args,
@@ -489,7 +519,26 @@ def cached(
         name=name,
     )
 
-    return cast(Callable[[F], CachedFunction[F]], func)
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class _CachedListFunctionDescriptor:
+    """Helper for `@cachedList`, we name it so that we can hook into it with mypy
+    plugin."""
+
+    cached_method_name: str
+    list_name: str
+    num_args: Optional[int] = None
+    name: Optional[str] = None
+
+    def __call__(self, orig: F) -> CachedFunction[F]:
+        d = DeferredCacheListDescriptor(
+            orig,
+            cached_method_name=self.cached_method_name,
+            list_name=self.list_name,
+            num_args=self.num_args,
+            name=self.name,
+        )
+        return cast(CachedFunction[F], d)
 
 
 def cachedList(
@@ -498,7 +547,7 @@ def cachedList(
     list_name: str,
     num_args: Optional[int] = None,
     name: Optional[str] = None,
-) -> Callable[[F], CachedFunction[F]]:
+) -> _CachedListFunctionDescriptor:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. One of the arguments
@@ -527,16 +576,13 @@ def cachedList(
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    func = lambda orig: DeferredCacheListDescriptor(
-        orig,
+    return _CachedListFunctionDescriptor(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
         name=name,
     )
 
-    return cast(Callable[[F], CachedFunction[F]], func)
-
 
 def _get_cache_key_builder(
     param_names: Sequence[str],