diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index faeef75506..6c162e9f34 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
- may return an ObservableDeferred.
+ will return a Deferred.
"""
__slots__ = (
@@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]):
key: KT,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
- ) -> Union[ObservableDeferred, VT]:
+ ) -> defer.Deferred:
"""Looks the key up in the caches.
+ For symmetry with set(), this method does *not* follow the synapse logcontext
+ rules: the logcontext will not be cleared on return, and the Deferred will run
+ its callbacks in the sentinel context. In other words: wrap the result with
+ make_deferred_yieldable() before `await`ing it.
+
Args:
- key(tuple)
- callback(fn): Gets called when the entry in the cache is invalidated
+ key:
+ callback: Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either an ObservableDeferred or the result itself
+ A Deferred which completes with the result. Note that this may later fail
+ if there is an ongoing set() operation which later completes with a failure.
Raises:
KeyError if the key is not found in the cache
@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
- return val.deferred
+ return val.deferred.observe()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
if val2 is _Sentinel.sentinel:
raise KeyError()
else:
- return val2
+ return defer.succeed(val2)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
- ) -> ObservableDeferred:
+ ) -> defer.Deferred:
+ """Adds a new entry to the cache (or updates an existing one).
+
+ The given `value` *must* be a Deferred.
+
+ First any existing entry for the same key is invalidated. Then a new entry
+ is added to the cache for the given key.
+
+ Until the `value` completes, calls to `get()` for the key will also result in an
+ incomplete Deferred, which will ultimately complete with the same result as
+ `value`.
+
+ If `value` completes successfully, subsequent calls to `get()` will then return
+ a completed deferred with the same result. If it *fails*, the cache is
+ invalidated and subequent calls to `get()` will raise a KeyError.
+
+ If another call to `set()` happens before `value` completes, then (a) any
+ invalidation callbacks registered in the interim will be called, (b) any
+ `get()`s in the interim will continue to complete with the result from the
+ *original* `value`, (c) any future calls to `get()` will complete with the
+ result from the *new* `value`.
+
+ It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+ if it is incomplete, it runs its callbacks in the sentinel context.
+
+ Args:
+ key: Key to be set
+ value: a deferred which will complete with a result to add to the cache
+ callback: An optional callback to be called when the entry is invalidated
+ """
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
if existing_entry:
existing_entry.invalidate()
+ # XXX: why don't we invalidate the entry in `self.cache` yet?
+
self._pending_deferred_cache[key] = entry
def compare_and_pop():
@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
- return observable
+
+ # we return a new Deferred which will be called before any subsequent observers.
+ return observable.observe()
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..a4172345ef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- ) # type: DeferredCache[Tuple, Any]
+ ) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
- cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
- if isinstance(cached_result_d, ObservableDeferred):
- observer = cached_result_d.observe()
- else:
- observer = defer.succeed(cached_result_d)
-
+ ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ ret = cache.set(cache_key, ret, callback=invalidate_callback)
- def onErr(f):
- cache.invalidate(cache_key)
- return f
-
- ret.addErrback(onErr)
-
- result_d = cache.set(cache_key, ret, callback=invalidate_callback)
- observer = result_d.observe()
-
- return make_deferred_yieldable(observer)
+ return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
- cache = cached_method.cache
+ cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not isinstance(res, ObservableDeferred):
- results[arg] = res
- elif not res.has_succeeded():
- res = res.observe()
+ if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
- results[arg] = res.get_result()
+ results[arg] = res.result
except KeyError:
missing.add(arg)
|