summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-01-17 11:18:13 +0000
committerErik Johnston <erik@matrix.org>2017-01-17 11:18:13 +0000
commitf85b6ca494ae587731d99196020cc74d7eca012a (patch)
tree453615672d4125641e192bb92cf9a7abdd68d345 /synapse/util/caches/descriptors.py
parentAdd ExpiringCache tests (diff)
downloadsynapse-f85b6ca494ae587731d99196020cc74d7eca012a.tar.xz
Speed up cache size calculation
Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, instead cache the size and
update that on insertion and deletion.

This requires changing the cache descriptors to have two caches, one for
pending deferreds and the other for the actual values. There's no reason
to evict from the pending deferreds as they won't take up any more
memory.
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py97
1 files changed, 74 insertions, 23 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d082c26b1f..b3b2d6092d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, popped_to_iterator
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,11 +42,23 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
-def deferred_size(deferred):
-    if deferred.called:
-        return len(deferred.result)
-    else:
-        return 1
+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
 
 
 class Cache(object):
@@ -58,13 +70,16 @@ class Cache(object):
         "sequence",
         "thread",
         "metrics",
+        "_pending_deferred_cache",
     )
 
     def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
+        self._pending_deferred_cache = cache_type()
+
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
-            size_callback=deferred_size if iterable else None,
+            size_callback=(lambda d: len(d.result)) if iterable else None,
         )
 
         self.name = name
@@ -84,7 +99,15 @@ class Cache(object):
                 )
 
     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -96,15 +119,39 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)
 
     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
         self.check_thread()
@@ -116,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
+
         self.cache.pop(key, None)
 
     def invalidate_many(self, key):
@@ -127,6 +178,12 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
+        val = self._pending_deferred_cache.pop(key, None)
+        if val is not None:
+            entry_dict, _ = val
+            for entry in popped_to_iterator(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -254,11 +311,6 @@ class CacheDescriptor(object):
 
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -272,7 +324,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -370,7 +422,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)
 
             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -393,8 +444,8 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )