diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 7e69cf55fb..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,7 @@ import logging
import threading
from collections import namedtuple
-import six
-from six import itervalues, string_types
+from six import itervalues
from prometheus_client import Gauge
@@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
@@ -124,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -148,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(deferred=value, callbacks=callbacks)
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -158,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -179,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
|