summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py235
1 files changed, 188 insertions, 47 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import inspect
 import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
+        self.cache_key_builder = get_cache_key_builder(
+            self.arg_names, self.arg_defaults
+        )
+
+
+class _LruCachedFunction(Generic[F]):
+    cache = None  # type: LruCache[CacheKey, Any]
+    __call__ = None  # type: F
+
+
+def lru_cache(
+    max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+    """A method decorator that applies a memoizing cache around the function.
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when they called caches are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+    """
+
+    def func(orig: F) -> _LruCachedFunction[F]:
+        desc = LruCacheDescriptor(
+            orig, max_entries=max_entries, cache_context=cache_context,
+        )
+        return cast(_LruCachedFunction[F], desc)
+
+    return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+    """Helper for @lru_cache"""
+
+    class _Sentinel(enum.Enum):
+        sentinel = object()
+
+    def __init__(
+        self, orig, max_entries: int = 1000, cache_context: bool = False,
+    ):
+        super().__init__(orig, num_args=None, cache_context=cache_context)
+        self.max_entries = max_entries
+
+    def __get__(self, obj, owner):
+        cache = LruCache(
+            cache_name=self.orig.__name__, max_size=self.max_entries,
+        )  # type: LruCache[CacheKey, Any]
+
+        get_cache_key = self.cache_key_builder
+        sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+        @functools.wraps(self.orig)
+        def _wrapped(*args, **kwargs):
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+            callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+            cache_key = get_cache_key(args, kwargs)
 
-class CacheDescriptor(_CacheDescriptorBase):
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret != sentinel:
+                return ret
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+            ret2 = self.orig(obj, *args, **kwargs)
+            cache.set(cache_key, ret2, callbacks=callbacks)
+
+            return ret2
+
+        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped.cache = cache
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
         cache_context=False,
         iterable=False,
     ):
-
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )  # type: DeferredCache[CacheKey, Any]
 
-        def get_cache_key_gen(args, kwargs):
-            """Given some args/kwargs return a generator that resolves into
-            the cache_key.
-
-            We loop through each arg name, looking up if its in the `kwargs`,
-            otherwise using the next argument in `args`. If there are no more
-            args then we try looking the arg name up in the defaults
-            """
-            pos = 0
-            for nm in self.arg_names:
-                if nm in kwargs:
-                    yield kwargs[nm]
-                elif pos < len(args):
-                    yield args[pos]
-                    pos += 1
-                else:
-                    yield self.arg_defaults[nm]
-
-        # By default our cache key is a tuple, but if there is only one item
-        # then don't bother wrapping in a tuple.  This is to save memory.
-        if self.num_args == 1:
-            nm = self.arg_names[0]
-
-            def get_cache_key(args, kwargs):
-                if nm in kwargs:
-                    return kwargs[nm]
-                elif len(args):
-                    return args[0]
-                else:
-                    return self.arg_defaults[nm]
-
-        else:
-
-            def get_cache_key(args, kwargs):
-                return tuple(get_cache_key_gen(args, kwargs))
+        get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
         def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_all = cache.invalidate_all
             wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         return wrapped
 
 
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
     on a lower level.
     """
 
+    Cache = Union[DeferredCache, LruCache]
+
     _cache_context_objects = (
         WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
 
-    def __init__(self, cache, cache_key):  # type: (DeferredCache, CacheKey) -> None
+    def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
@@ -396,8 +475,8 @@ class _CacheContext:
 
     @classmethod
     def get_instance(
-        cls, cache, cache_key
-    ):  # type: (DeferredCache, CacheKey) -> _CacheContext
+        cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+    ) -> "_CacheContext":
         """Returns an instance constructed with the given arguments.
 
         A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
 ) -> Callable[[F], _CachedFunction[F]]:
-    func = lambda orig: CacheDescriptor(
+    func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    func = lambda orig: CacheListDescriptor(
+    func = lambda orig: DeferredCacheListDescriptor(
         orig,
         cached_method_name=cached_method_name,
         list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+    param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+    """Construct a function which will build cache keys suitable for a cached function
+
+    Args:
+        param_names: list of formal parameter names for the cached function
+        param_defaults: a mapping from parameter name to default value for that param
+
+    Returns:
+        A function which will take an (args, kwargs) pair and return a cache key
+    """
+
+    # By default our cache key is a tuple, but if there is only one item
+    # then don't bother wrapping in a tuple.  This is to save memory.
+
+    if len(param_names) == 1:
+        nm = param_names[0]
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            if nm in kwargs:
+                return kwargs[nm]
+            elif len(args):
+                return args[0]
+            else:
+                return param_defaults[nm]
+
+    else:
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+    return get_cache_key
+
+
+def _get_cache_key_gen(
+    param_names: Iterable[str],
+    param_defaults: Mapping[str, Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+    """Given some args/kwargs return a generator that resolves into
+    the cache_key.
+
+    This is essentially the same operation as `inspect.getcallargs`, but optimised so
+    that we don't need to inspect the target function for each call.
+    """
+
+    # We loop through each arg name, looking up if its in the `kwargs`,
+    # otherwise using the next argument in `args`. If there are no more
+    # args then we try looking the arg name up in the defaults.
+    pos = 0
+    for nm in param_names:
+        if nm in kwargs:
+            yield kwargs[nm]
+        elif pos < len(args):
+            yield args[pos]
+            pos += 1
+        else:
+            yield param_defaults[nm]