summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py67
1 files changed, 34 insertions, 33 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 5997603b3c..e76ee2779a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -83,45 +86,42 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't  *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(key, value)
 
+    def prefill(self, key, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
-        self.cache[keyargs] = value
+        self.cache[key] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, key):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(key, tuple):
+            raise ValueError("keyargs must be a tuple.")
+
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(keyargs, None)
+        self.cache.pop(key, None)
 
     def invalidate_all(self):
         self.check_thread()
@@ -168,20 +168,21 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = cache.get(*keyargs)
+                cached_result_d = self.cache.get(cache_key)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -191,7 +192,7 @@ class CacheDescriptor(object):
                         if actual_result != cached_result:
                             logger.error(
                                 "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, keyargs,
+                                self.orig.__name__, cache_key,
                                 cached_result, actual_result,
                             )
                             raise ValueError("Stale cache entry")
@@ -203,7 +204,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
                     self.function_to_call,
@@ -211,19 +212,19 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    cache.invalidate(*keyargs)
+                    self.cache.invalidate(cache_key)
                     return f
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=False)
-                cache.update(sequence, *(keyargs + [ret]))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
 
                 return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped