summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..5ac2530a6a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,10 +18,12 @@ import inspect
 import logging
 import threading
 from collections import namedtuple
+from typing import Any, cast
 
 from six import itervalues
 
 from prometheus_client import Gauge
+from typing_extensions import Protocol
 
 from twisted.internet import defer
 
@@ -37,6 +39,18 @@ from . import register_cache
 logger = logging.getLogger(__name__)
 
 
+class _CachedFunction(Protocol):
+    invalidate = None  # type: Any
+    invalidate_all = None  # type: Any
+    invalidate_many = None  # type: Any
+    prefill = None  # type: Any
+    cache = None  # type: Any
+    num_args = None  # type: Any
+
+    def __name__(self):
+        ...
+
+
 cache_pending_metric = Gauge(
     "synapse_util_caches_cache_pending",
     "Number of lookups currently pending for this cache",
@@ -245,7 +259,9 @@ class Cache(object):
 
 
 class _CacheDescriptorBase(object):
-    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+    def __init__(
+        self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+    ):
         self.orig = orig
 
         if inlineCallbacks:
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
                 return tuple(get_cache_key_gen(args, kwargs))
 
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def _wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(observer)
 
+        wrapped = cast(_CachedFunction, _wrapped)
+
         if self.num_args == 1:
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)