diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 141 |
1 files changed, 96 insertions, 45 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd7..73eea157a4 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,6 +15,7 @@ import logging from synapse.api.errors import StoreError +from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -27,6 +28,7 @@ from twisted.internet import defer from collections import namedtuple, OrderedDict import functools +import inspect import sys import time import threading @@ -55,9 +57,12 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): - def __init__(self, name, max_entries=1000, keylen=1, lru=False): + def __init__(self, name, max_entries=1000, keylen=1, lru=True): if lru: self.cache = LruCache(max_size=max_entries) self.max_entries = None @@ -81,45 +86,44 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) - raise KeyError() - def update(self, sequence, *args): + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + self.prefill(key, value) + def prefill(self, key, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) - self.cache[keyargs] = value + self.cache[key] = value - def invalidate(self, *keyargs): + def invalidate(self, key): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(key, tuple): + raise TypeError( + "The cache key must be a tuple not %r" % (type(key),) + ) + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop(keyargs, None) + self.cache.pop(key, None) def invalidate_all(self): self.check_thread() @@ -130,6 +134,9 @@ class Cache(object): class CacheDescriptor(object): """ A method decorator that applies a memoizing cache around the function. + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + The function is presumed to take zero or more arguments, which are used in a tuple as the key for the cache. Hits are served directly from the cache; misses use the function body to generate the value. @@ -141,58 +148,92 @@ class CacheDescriptor(object): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=False): + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): self.orig = orig + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + self.max_entries = max_entries self.num_args = num_args self.lru = lru - def __get__(self, obj, objtype=None): - cache = Cache( + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) - @defer.inlineCallbacks - def wrapped(*keyargs): + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result = cache.get(*keyargs[:self.num_args]) + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() if DEBUG_CACHES: - actual_result = yield self.orig(obj, *keyargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f - ret = yield self.orig(obj, *keyargs) + ret.addErrback(onErr) - cache.update(sequence, *keyargs[:self.num_args] + (ret,)) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) - defer.returnValue(ret) + return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped return wrapped -def cached(max_entries=1000, num_args=1, lru=False): +def cached(max_entries=1000, num_args=1, lru=True): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False): ) +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() |