summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async.py9
-rw-r--r--synapse/util/caches/descriptors.py35
2 files changed, 34 insertions, 10 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0d6f48e2d8..40be7fe7e3 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -102,6 +102,15 @@ class ObservableDeferred(object):
     def observers(self):
         return self._observers
 
+    def has_called(self):
+        return self._result is not None
+
+    def has_succeeded(self):
+        return self._result is not None and self._result[0] is True
+
+    def get_result(self):
+        return self._result[1]
+
     def __getattr__(self, name):
         return getattr(self._deferred, name)
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d25c9e762..f31dfb22b7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -33,6 +33,7 @@ import functools
 import inspect
 import threading
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -302,16 +303,21 @@ class CacheListDescriptor(object):
 
             # cached is a dict arg -> deferred, where deferred results in a
             # 2-tuple (`arg`, `result`)
-            cached = {}
+            results = {}
+            cached_defers = {}
             missing = []
             for arg in list_args:
                 key = list(keyargs)
                 key[self.list_pos] = arg
 
                 try:
-                    res = cache.get(tuple(key)).observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-                    cached[arg] = res
+                    res = cache.get(tuple(key))
+                    if not res.has_succeeded():
+                        res = res.observe()
+                        res.addCallback(lambda r, arg: (arg, r), arg)
+                        cached_defers[arg] = res
+                    else:
+                        results[arg] = res.get_result()
                 except KeyError:
                     missing.append(arg)
 
@@ -349,12 +355,21 @@ class CacheListDescriptor(object):
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
 
-                    cached[arg] = res
-
-            return preserve_context_over_deferred(defer.gatherResults(
-                cached.values(),
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+                    cached_defers[arg] = res
+
+            if cached_defers:
+                def update_results_dict(res):
+                    results.update(res)
+                    return results
+
+                return preserve_context_over_deferred(defer.gatherResults(
+                    cached_defers.values(),
+                    consumeErrors=True,
+                ).addCallback(update_results_dict).addErrback(
+                    unwrapFirstError
+                ))
+            else:
+                return results
 
         obj.__dict__[self.orig.__name__] = wrapped