summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py32
1 files changed, 7 insertions, 25 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..a4172345ef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.deferred_cache import DeferredCache
 
 logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[Tuple, Any]
+        )  # type: DeferredCache[CacheKey, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
                 kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
 
             try:
-                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
-                if isinstance(cached_result_d, ObservableDeferred):
-                    observer = cached_result_d.observe()
-                else:
-                    observer = defer.succeed(cached_result_d)
-
+                ret = cache.get(cache_key, callback=invalidate_callback)
             except KeyError:
                 ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                ret = cache.set(cache_key, ret, callback=invalidate_callback)
 
-                def onErr(f):
-                    cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
-                observer = result_d.observe()
-
-            return make_deferred_yieldable(observer)
+            return make_deferred_yieldable(ret)
 
         wrapped = cast(_CachedFunction, _wrapped)
 
@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache
+        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
             for arg in list_args:
                 try:
                     res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not isinstance(res, ObservableDeferred):
-                        results[arg] = res
-                    elif not res.has_succeeded():
-                        res = res.observe()
+                    if not res.called:
                         res.addCallback(update_results_dict, arg)
                         cached_defers.append(res)
                     else:
-                        results[arg] = res.get_result()
+                        results[arg] = res.result
                 except KeyError:
                     missing.add(arg)