summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py141
1 files changed, 96 insertions, 45 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8f812f0fd7..73eea157a4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
 import logging
 
 from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
 from collections import namedtuple, OrderedDict
 
 import functools
+import inspect
 import sys
 import time
 import threading
@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
         if lru:
             self.cache = LruCache(max_size=max_entries)
             self.max_entries = None
@@ -81,45 +86,44 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, key, default=_CacheSentinel):
+        val = self.cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, key, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't  *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(key, value)
 
+    def prefill(self, key, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
-        self.cache[keyargs] = value
+        self.cache[key] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, key):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(keyargs, None)
+        self.cache.pop(key, None)
 
     def invalidate_all(self):
         self.check_thread()
@@ -130,6 +134,9 @@ class Cache(object):
 class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
+    This caches deferreds, rather than the results themselves. Deferreds that
+    fail are removed from the cache.
+
     The function is presumed to take zero or more arguments, which are used in
     a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
@@ -141,58 +148,92 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+                 inlineCallbacks=False):
         self.orig = orig
 
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwars)"
+                % (orig.__name__,)
+            )
+
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
-        @defer.inlineCallbacks
-        def wrapped(*keyargs):
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result = cache.get(*keyargs[:self.num_args])
+                cached_result_d = self.cache.get(cache_key)
+
+                observer = cached_result_d.observe()
                 if DEBUG_CACHES:
-                    actual_result = yield self.orig(obj, *keyargs)
-                    if actual_result != cached_result:
-                        logger.error(
-                            "Stale cache entry %s%r: cached: %r, actual %r",
-                            self.orig.__name__, keyargs,
-                            cached_result, actual_result,
-                        )
-                        raise ValueError("Stale cache entry")
-                defer.returnValue(cached_result)
+                    @defer.inlineCallbacks
+                    def check_result(cached_result):
+                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
+                        if actual_result != cached_result:
+                            logger.error(
+                                "Stale cache entry %s%r: cached: %r, actual %r",
+                                self.orig.__name__, cache_key,
+                                cached_result, actual_result,
+                            )
+                            raise ValueError("Stale cache entry")
+                        defer.returnValue(cached_result)
+                    observer.addCallback(check_result)
+
+                return observer
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
+
+                ret = defer.maybeDeferred(
+                    self.function_to_call,
+                    obj, *args, **kwargs
+                )
+
+                def onErr(f):
+                    self.cache.invalidate(cache_key)
+                    return f
 
-                ret = yield self.orig(obj, *keyargs)
+                ret.addErrback(onErr)
 
-                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, cache_key, ret)
 
-                defer.returnValue(ret)
+                return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+def cached(max_entries=1000, num_args=1, lru=True):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
     )
 
 
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru,
+        inlineCallbacks=True,
+    )
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()