diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/caches/descriptors.py | 74 |
1 files changed, 58 insertions, 16 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f1..c3c5c16db9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ class _CacheDescriptorBase: self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -76,6 +78,13 @@ class _CacheDescriptorBase: arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( @@ -88,6 +97,9 @@ class _CacheDescriptorBase: " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +117,12 @@ class _CacheDescriptorBase: # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). + if uncached_args is not None: + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -119,8 +137,8 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( - self.arg_names, self.arg_defaults + self.cache_key_builder = _get_cache_key_builder( + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase): max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -530,8 +558,10 @@ class _CacheContext: def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +571,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, @@ -551,7 +582,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. @@ -590,13 +621,16 @@ def cachedList( return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] +def _get_cache_key_builder( + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include the parameter name in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ -608,6 +642,7 @@ def get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -620,13 +655,18 @@ def get_cache_key_builder( else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -637,16 +677,18 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] |