diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 115 |
1 files changed, 91 insertions, 24 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 43f66ec4be..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,22 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging import threading -from collections import namedtuple +from typing import Any, Tuple, Union, cast +from weakref import WeakValueDictionary from six import itervalues from prometheus_client import Gauge +from typing_extensions import Protocol from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -36,6 +38,20 @@ from . import register_cache logger = logging.getLogger(__name__) +CacheKey = Union[Tuple, Any] + + +class _CachedFunction(Protocol): + invalidate = None # type: Any + invalidate_all = None # type: Any + invalidate_many = None # type: Any + prefill = None # type: Any + cache = None # type: Any + num_args = None # type: Any + + def __name__(self): + ... + cache_pending_metric = Gauge( "synapse_util_caches_cache_pending", @@ -65,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -73,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -83,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -95,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -245,7 +287,9 @@ class Cache(object): class _CacheDescriptorBase(object): - def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + def __init__( + self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False + ): self.orig = orig if inlineCallbacks: @@ -253,7 +297,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: @@ -352,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, @@ -404,7 +446,7 @@ class CacheDescriptor(_CacheDescriptorBase): return tuple(get_cache_key_gen(args, kwargs)) @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def _wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -414,7 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase): # Add our own `cache_context` to argument list if the wrapped function # has asked for one if self.add_cache_context: - kwargs["cache_context"] = _CacheContext(cache, cache_key) + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -422,7 +464,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -440,6 +482,8 @@ class CacheDescriptor(_CacheDescriptorBase): return make_deferred_yieldable(observer) + wrapped = cast(_CachedFunction, _wrapped) + if self.num_args == 1: wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.prefill = lambda key, val: cache.prefill(key[0], val) @@ -464,9 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ def __init__( @@ -600,21 +643,45 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped return wrapped -class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. - def invalidate(self): - self.cache.invalidate(self.key) +class _CacheContext: + """Holds cache information from the cached function higher in the calling order. + + Can be used to invalidate the higher level cache entry if something changes + on a lower level. + """ + + _cache_context_objects = ( + WeakValueDictionary() + ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + + def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + self._cache = cache + self._cache_key = cache_key + + def invalidate(self): # type: () -> None + """Invalidates the cache entry referred to by the context.""" + self._cache.invalidate(self._cache_key) + + @classmethod + def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + """Returns an instance constructed with the given arguments. + + A new instance is only created if none already exists. + """ + + # We make sure there are no identical _CacheContext instances. This is + # important in particular to dedupe when we add callbacks to lru cache + # nodes, otherwise the number of callbacks would grow. + return cls._cache_context_objects.setdefault( + (cache, cache_key), cls(cache, cache_key) + ) def cached( |