diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 61 |
1 files changed, 31 insertions, 30 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0872a438f1..e07cf3b58a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,6 +57,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -83,41 +86,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) - raise KeyError() - def update(self, sequence, *args): + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, keyargs, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -168,20 +168,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -202,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -210,19 +211,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped |