diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/util/caches/descriptors.py | 235 |
1 files changed, 188 insertions, 47 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5d7fffee66..a924140cdf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,10 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import functools import inspect import logging -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from twisted.internet import defer @@ -24,6 +37,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: _CachedFunction, num_args, cache_context=False): + def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -97,8 +111,107 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context + self.cache_key_builder = get_cache_key_builder( + self.arg_names, self.arg_defaults + ) + + +class _LruCachedFunction(Generic[F]): + cache = None # type: LruCache[CacheKey, Any] + __call__ = None # type: F + + +def lru_cache( + max_entries: int = 1000, cache_context: bool = False, +) -> Callable[[F], _LruCachedFunction[F]]: + """A method decorator that applies a memoizing cache around the function. + + This is more-or-less a drop-in equivalent to functools.lru_cache, although note + that the signature is slightly different. + + The main differences with functools.lru_cache are: + (a) the size of the cache can be controlled via the cache_factor mechanism + (b) the wrapped function can request a "cache_context" which provides a + callback mechanism to indicate that the result is no longer valid + (c) prometheus metrics are exposed automatically. + + The function should take zero or more arguments, which are used as the key for the + cache. Single-argument functions use that argument as the cache key; otherwise the + arguments are built into a tuple. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example: + + @lru_cache(cache_context=True) + def foo(self, key, cache_context): + r1 = self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = self.bar2(key, on_invalidate=cache_context.invalidate) + return r1 + r2 + + The wrapped function also has a 'cache' property which offers direct access to the + underlying LruCache. + """ + + def func(orig: F) -> _LruCachedFunction[F]: + desc = LruCacheDescriptor( + orig, max_entries=max_entries, cache_context=cache_context, + ) + return cast(_LruCachedFunction[F], desc) + + return func + + +class LruCacheDescriptor(_CacheDescriptorBase): + """Helper for @lru_cache""" + + class _Sentinel(enum.Enum): + sentinel = object() + + def __init__( + self, orig, max_entries: int = 1000, cache_context: bool = False, + ): + super().__init__(orig, num_args=None, cache_context=cache_context) + self.max_entries = max_entries + + def __get__(self, obj, owner): + cache = LruCache( + cache_name=self.orig.__name__, max_size=self.max_entries, + ) # type: LruCache[CacheKey, Any] + + get_cache_key = self.cache_key_builder + sentinel = LruCacheDescriptor._Sentinel.sentinel + + @functools.wraps(self.orig) + def _wrapped(*args, **kwargs): + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = (invalidate_callback,) if invalidate_callback else () + + cache_key = get_cache_key(args, kwargs) -class CacheDescriptor(_CacheDescriptorBase): + ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) + if ret != sentinel: + return ret + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) + + ret2 = self.orig(obj, *args, **kwargs) + cache.set(cache_key, ret2, callbacks=callbacks) + + return ret2 + + wrapped = cast(_CachedFunction, _wrapped) + wrapped.cache = cache + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class DeferredCacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=False, iterable=False, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries @@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] - def get_cache_key_gen(args, kwargs): - """Given some args/kwargs return a generator that resolves into - the cache_key. - - We loop through each arg name, looking up if its in the `kwargs`, - otherwise using the next argument in `args`. If there are no more - args then we try looking the arg name up in the defaults - """ - pos = 0 - for nm in self.arg_names: - if nm in kwargs: - yield kwargs[nm] - elif pos < len(args): - yield args[pos] - pos += 1 - else: - yield self.arg_defaults[nm] - - # By default our cache key is a tuple, but if there is only one item - # then don't bother wrapping in a tuple. This is to save memory. - if self.num_args == 1: - nm = self.arg_names[0] - - def get_cache_key(args, kwargs): - if nm in kwargs: - return kwargs[nm] - elif len(args): - return args[0] - else: - return self.arg_defaults[nm] - - else: - - def get_cache_key(args, kwargs): - return tuple(get_cache_key_gen(args, kwargs)) + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) def _wrapped(*args, **kwargs): @@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill @@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase): return wrapped -class CacheListDescriptor(_CacheDescriptorBase): +class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes @@ -382,11 +459,13 @@ class _CacheContext: on a lower level. """ + Cache = Union[DeferredCache, LruCache] + _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None + def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key @@ -396,8 +475,8 @@ class _CacheContext: @classmethod def get_instance( - cls, cache, cache_key - ): # type: (DeferredCache, CacheKey) -> _CacheContext + cls, cache: "_CacheContext.Cache", cache_key: CacheKey + ) -> "_CacheContext": """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. @@ -418,7 +497,7 @@ def cached( cache_context: bool = False, iterable: bool = False, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: CacheDescriptor( + func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -460,7 +539,7 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: CacheListDescriptor( + func = lambda orig: DeferredCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, @@ -468,3 +547,65 @@ def cachedList( ) return cast(Callable[[F], _CachedFunction[F]], func) + + +def get_cache_key_builder( + param_names: Sequence[str], param_defaults: Mapping[str, Any] +) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: + """Construct a function which will build cache keys suitable for a cached function + + Args: + param_names: list of formal parameter names for the cached function + param_defaults: a mapping from parameter name to default value for that param + + Returns: + A function which will take an (args, kwargs) pair and return a cache key + """ + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + + if len(param_names) == 1: + nm = param_names[0] + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return param_defaults[nm] + + else: + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + + return get_cache_key + + +def _get_cache_key_gen( + param_names: Iterable[str], + param_defaults: Mapping[str, Any], + args: Sequence[Any], + kwargs: Mapping[str, Any], +) -> Iterable[Any]: + """Given some args/kwargs return a generator that resolves into + the cache_key. + + This is essentially the same operation as `inspect.getcallargs`, but optimised so + that we don't need to inspect the target function for each call. + """ + + # We loop through each arg name, looking up if its in the `kwargs`, + # otherwise using the next argument in `args`. If there are no more + # args then we try looking the arg name up in the defaults. + pos = 0 + for nm in param_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield param_defaults[nm] |