summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py97
1 files changed, 74 insertions, 23 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d082c26b1f..b3b2d6092d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, popped_to_iterator
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,11 +42,23 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
-def deferred_size(deferred):
-    if deferred.called:
-        return len(deferred.result)
-    else:
-        return 1
+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
 
 
 class Cache(object):
@@ -58,13 +70,16 @@ class Cache(object):
         "sequence",
         "thread",
         "metrics",
+        "_pending_deferred_cache",
     )
 
     def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
+        self._pending_deferred_cache = cache_type()
+
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
-            size_callback=deferred_size if iterable else None,
+            size_callback=(lambda d: len(d.result)) if iterable else None,
         )
 
         self.name = name
@@ -84,7 +99,15 @@ class Cache(object):
                 )
 
     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -96,15 +119,39 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)
 
     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
         self.check_thread()
@@ -116,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
+
         self.cache.pop(key, None)
 
     def invalidate_many(self, key):
@@ -127,6 +178,12 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
+        val = self._pending_deferred_cache.pop(key, None)
+        if val is not None:
+            entry_dict, _ = val
+            for entry in popped_to_iterator(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -254,11 +311,6 @@ class CacheDescriptor(object):
 
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -272,7 +324,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -370,7 +422,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)
 
             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -393,8 +444,8 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )