diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 85 |
1 files changed, 35 insertions, 50 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index d4751769e4..32089b05e5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -58,6 +58,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -74,11 +77,6 @@ class Cache(object): self.thread = None caches_by_name[name] = self.cache - class Sentinel(object): - __slots__ = [] - - self.sentinel = Sentinel() - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -89,52 +87,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - try: - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val - val = self.cache.get(keyargs, self.sentinel) - if val is not self.sentinel: - cache_counter.inc_hits(self.name) - return val + cache_counter.inc_misses(self.name) - cache_counter.inc_misses(self.name) + if default is _CacheSentinel: raise KeyError() - except KeyError: - raise - except: - logger.exception("Cache.get failed for %s" % (self.name,)) - raise - - def update(self, sequence, *args): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - except: - logger.exception("Cache.update failed for %s" % (self.name,)) - raise - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] + else: + return default - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + def update(self, sequence, keyargs, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -185,20 +169,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -219,7 +204,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -227,19 +212,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped |