summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py85
1 files changed, 35 insertions, 50 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d4751769e4..32089b05e5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -58,6 +58,9 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -74,11 +77,6 @@ class Cache(object):
         self.thread = None
         caches_by_name[name] = self.cache
 
-        class Sentinel(object):
-            __slots__ = []
-
-        self.sentinel = Sentinel()
-
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -89,52 +87,38 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, *keyargs):
-        try:
-            if len(keyargs) != self.keylen:
-                raise ValueError("Expected a key to have %d items", self.keylen)
+    def get(self, keyargs, default=_CacheSentinel):
+        val = self.cache.get(keyargs, _CacheSentinel)
+        if val is not _CacheSentinel:
+            cache_counter.inc_hits(self.name)
+            return val
 
-            val = self.cache.get(keyargs, self.sentinel)
-            if val is not self.sentinel:
-                cache_counter.inc_hits(self.name)
-                return val
+        cache_counter.inc_misses(self.name)
 
-            cache_counter.inc_misses(self.name)
+        if default is _CacheSentinel:
             raise KeyError()
-        except KeyError:
-            raise
-        except:
-            logger.exception("Cache.get failed for %s" % (self.name,))
-            raise
-
-    def update(self, sequence, *args):
-        try:
-            self.check_thread()
-            if self.sequence == sequence:
-                # Only update the cache if the caches sequence number matches the
-                # number that the cache had before the SELECT was started (SYN-369)
-                self.prefill(*args)
-        except:
-            logger.exception("Cache.update failed for %s" % (self.name,))
-            raise
-
-    def prefill(self, *args):  # because I can't  *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
+        else:
+            return default
 
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+    def update(self, sequence, keyargs, value):
+        self.check_thread()
+        if self.sequence == sequence:
+            # Only update the cache if the caches sequence number matches the
+            # number that the cache had before the SELECT was started (SYN-369)
+            self.prefill(keyargs, value)
 
+    def prefill(self, keyargs, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
         self.cache[keyargs] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, keyargs):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(keyargs, tuple):
+            raise ValueError("keyargs must be a tuple.")
+
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
@@ -185,20 +169,21 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = cache.get(*keyargs)
+                cached_result_d = self.cache.get(keyargs)
 
                 if DEBUG_CACHES:
 
@@ -219,7 +204,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
                     self.function_to_call,
@@ -227,19 +212,19 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    cache.invalidate(*keyargs)
+                    self.cache.invalidate(keyargs)
                     return f
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=False)
-                cache.update(sequence, *(keyargs + [ret]))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, keyargs, ret)
 
                 return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped