summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py110
1 files changed, 88 insertions, 22 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..a9ea97fd46 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,6 +42,25 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -51,12 +70,16 @@ class Cache(object):
         "sequence",
         "thread",
         "metrics",
+        "_pending_deferred_cache",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
+        self._pending_deferred_cache = cache_type()
+
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type
+            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            size_callback=(lambda d: len(d.result)) if iterable else None,
         )
 
         self.name = name
@@ -76,7 +99,15 @@ class Cache(object):
                 )
 
     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -88,15 +119,39 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)
 
     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
         self.check_thread()
@@ -108,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
+
         self.cache.pop(key, None)
 
     def invalidate_many(self, key):
@@ -119,6 +178,12 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
+        val = self._pending_deferred_cache.pop(key, None)
+        if val is not None:
+            entry_dict, _ = val
+            for entry in iterate_tree_cache_entry(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -155,7 +220,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False, cache_context=False):
+                 inlineCallbacks=False, cache_context=False, iterable=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -169,6 +234,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree
 
+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
@@ -203,6 +270,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )
 
         @functools.wraps(self.orig)
@@ -243,11 +311,6 @@ class CacheDescriptor(object):
 
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -261,7 +324,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -359,7 +422,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)
 
             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -382,8 +444,8 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )
 
@@ -421,17 +483,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +504,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )