summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py61
1 files changed, 31 insertions, 30 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0872a438f1..e07cf3b58a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -83,41 +86,38 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, keyargs, default=_CacheSentinel):
+        val = self.cache.get(keyargs, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, keyargs, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't  *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(keyargs, value)
 
+    def prefill(self, keyargs, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
         self.cache[keyargs] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, keyargs):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(keyargs, tuple):
+            raise ValueError("keyargs must be a tuple.")
+
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
@@ -168,20 +168,21 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = cache.get(*keyargs)
+                cached_result_d = self.cache.get(keyargs)
 
                 if DEBUG_CACHES:
 
@@ -202,7 +203,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
                     self.function_to_call,
@@ -210,19 +211,19 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    cache.invalidate(*keyargs)
+                    self.cache.invalidate(keyargs)
                     return f
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=False)
-                cache.update(sequence, *(keyargs + [ret]))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, keyargs, ret)
 
                 return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped