diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d4751769e4..32089b05e5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -58,6 +58,9 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -74,11 +77,6 @@ class Cache(object):
self.thread = None
caches_by_name[name] = self.cache
- class Sentinel(object):
- __slots__ = []
-
- self.sentinel = Sentinel()
-
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -89,52 +87,38 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- try:
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ def get(self, keyargs, default=_CacheSentinel):
+ val = self.cache.get(keyargs, _CacheSentinel)
+ if val is not _CacheSentinel:
+ cache_counter.inc_hits(self.name)
+ return val
- val = self.cache.get(keyargs, self.sentinel)
- if val is not self.sentinel:
- cache_counter.inc_hits(self.name)
- return val
+ cache_counter.inc_misses(self.name)
- cache_counter.inc_misses(self.name)
+ if default is _CacheSentinel:
raise KeyError()
- except KeyError:
- raise
- except:
- logger.exception("Cache.get failed for %s" % (self.name,))
- raise
-
- def update(self, sequence, *args):
- try:
- self.check_thread()
- if self.sequence == sequence:
- # Only update the cache if the caches sequence number matches the
- # number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
- except:
- logger.exception("Cache.update failed for %s" % (self.name,))
- raise
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
+ else:
+ return default
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ def update(self, sequence, keyargs, value):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(keyargs, value)
+ def prefill(self, keyargs, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, keyargs):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(keyargs, tuple):
+ raise ValueError("keyargs must be a tuple.")
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
@@ -185,20 +169,21 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
- def __get__(self, obj, objtype=None):
- cache = Cache(
+ self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
+ def __get__(self, obj, objtype=None):
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+ keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result_d = cache.get(*keyargs)
+ cached_result_d = self.cache.get(keyargs)
if DEBUG_CACHES:
@@ -219,7 +204,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
@@ -227,19 +212,19 @@ class CacheDescriptor(object):
)
def onErr(f):
- cache.invalidate(*keyargs)
+ self.cache.invalidate(keyargs)
return f
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=False)
- cache.update(sequence, *(keyargs + [ret]))
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, keyargs, ret)
return ret.observe()
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
|