diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..a4172345ef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- ) # type: DeferredCache[Tuple, Any]
+ ) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
- cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
- if isinstance(cached_result_d, ObservableDeferred):
- observer = cached_result_d.observe()
- else:
- observer = defer.succeed(cached_result_d)
-
+ ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ ret = cache.set(cache_key, ret, callback=invalidate_callback)
- def onErr(f):
- cache.invalidate(cache_key)
- return f
-
- ret.addErrback(onErr)
-
- result_d = cache.set(cache_key, ret, callback=invalidate_callback)
- observer = result_d.observe()
-
- return make_deferred_yieldable(observer)
+ return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
- cache = cached_method.cache
+ cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not isinstance(res, ObservableDeferred):
- results[arg] = res
- elif not res.has_succeeded():
- res = res.observe()
+ if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
- results[arg] = res.get_result()
+ results[arg] = res.result
except KeyError:
missing.add(arg)
|