summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py94
1 files changed, 60 insertions, 34 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675db2f448..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,9 @@ import logging
 import threading
 from collections import namedtuple
 
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
 
 from twisted.internet import defer
 
@@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
 logger = logging.getLogger(__name__)
 
 
+cache_pending_metric = Gauge(
+    "synapse_util_caches_cache_pending",
+    "Number of lookups currently pending for this cache",
+    ["name"],
+)
+
 _CacheSentinel = object()
 
 
@@ -82,11 +88,19 @@ class Cache(object):
         self.name = name
         self.keylen = keylen
         self.thread = None
-        self.metrics = register_cache("cache", name, self.cache)
+        self.metrics = register_cache(
+            "cache",
+            name,
+            self.cache,
+            collect_callback=self._metrics_collection_callback,
+        )
 
     def _on_evicted(self, evicted_count):
         self.metrics.inc_evictions(evicted_count)
 
+    def _metrics_collection_callback(self):
+        cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -108,7 +122,7 @@ class Cache(object):
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either a Deferred or the raw result
+            Either an ObservableDeferred or the raw result
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -132,9 +146,14 @@ class Cache(object):
             return default
 
     def set(self, key, value, callback=None):
+        if not isinstance(value, defer.Deferred):
+            raise TypeError("not a Deferred")
+
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(deferred=value, callbacks=callbacks)
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = defer.maybeDeferred(observable.observe)
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -142,20 +161,31 @@ class Cache(object):
 
         self._pending_deferred_cache[key] = entry
 
-        def shuffle(result):
+        def compare_and_pop():
+            """Check if our entry is still the one in _pending_deferred_cache, and
+            if so, pop it.
+
+            Returns true if the entries matched.
+            """
             existing_entry = self._pending_deferred_cache.pop(key, None)
             if existing_entry is entry:
+                return True
+
+            # oops, the _pending_deferred_cache has been updated since
+            # we started our query, so we are out of date.
+            #
+            # Better put back whatever we took out. (We do it this way
+            # round, rather than peeking into the _pending_deferred_cache
+            # and then removing on a match, to make the common case faster)
+            if existing_entry is not None:
+                self._pending_deferred_cache[key] = existing_entry
+
+            return False
+
+        def cb(result):
+            if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
-                # oops, the _pending_deferred_cache has been updated since
-                # we started our query, so we are out of date.
-                #
-                # Better put back whatever we took out. (We do it this way
-                # round, rather than peeking into the _pending_deferred_cache
-                # and then removing on a match, to make the common case faster)
-                if existing_entry is not None:
-                    self._pending_deferred_cache[key] = existing_entry
-
                 # we're not going to put this entry into the cache, so need
                 # to make sure that the invalidation callbacks are called.
                 # That was probably done when _pending_deferred_cache was
@@ -163,9 +193,16 @@ class Cache(object):
                 # `invalidate` being previously called, in which case it may
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
-            return result
 
-        entry.deferred.addCallback(shuffle)
+        def eb(_fail):
+            compare_and_pop()
+            entry.invalidate()
+
+        # once the deferred completes, we can move the entry from the
+        # _pending_deferred_cache to the real cache.
+        #
+        observer.addCallbacks(cb, eb)
+        return observable
 
     def prefill(self, key, value, callback=None):
         callbacks = [callback] if callback else []
@@ -289,7 +326,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
-            defer.returnValue(r1 + r2)
+            return r1 + r2
 
     Args:
         num_args (int): number of positional arguments (excluding ``self`` and
@@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                # If our cache_key is a string on py2, try to convert to ascii
-                # to save a bit of space in large caches. Py3 does this
-                # internally automatically.
-                if six.PY2 and isinstance(cache_key, string_types):
-                    cache_key = to_ascii(cache_key)
-
-                result_d = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, result_d, callback=invalidate_callback)
+                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            if isinstance(observer, defer.Deferred):
-                return make_deferred_yieldable(observer)
-            else:
-                return observer
+            return make_deferred_yieldable(observer)
 
         if self.num_args == 1:
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     missing.add(arg)
 
             if missing:
-                # we need an observable deferred for each entry in the list,
+                # we need a deferred for each entry in the list,
                 # which we put in the cache. Each deferred resolves with the
                 # relevant result for that key.
                 deferreds_map = {}
@@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     deferred = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
-                    observable = ObservableDeferred(deferred)
-                    cache.set(key, observable, callback=invalidate_callback)
+                    cache.set(key, deferred, callback=invalidate_callback)
 
                 def complete_all(res):
                     # the wrapped function has completed. It returns a