diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 5997603b3c..e76ee2779a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -83,45 +86,42 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
+ return val
cache_counter.inc_misses(self.name)
- raise KeyError()
- def update(self, sequence, *args):
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ self.prefill(key, value)
+ def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
- self.cache[keyargs] = value
+ self.cache[key] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, key):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(key, tuple):
+ raise ValueError("keyargs must be a tuple.")
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(keyargs, None)
+ self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@@ -168,20 +168,21 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
- def __get__(self, obj, objtype=None):
- cache = Cache(
+ self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
+ def __get__(self, obj, objtype=None):
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result_d = cache.get(*keyargs)
+ cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -191,7 +192,7 @@ class CacheDescriptor(object):
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, keyargs,
+ self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
@@ -203,7 +204,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
@@ -211,19 +212,19 @@ class CacheDescriptor(object):
)
def onErr(f):
- cache.invalidate(*keyargs)
+ self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=False)
- cache.update(sequence, *(keyargs + [ret]))
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
return ret.observe()
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
|