summary refs log tree commit diff
path: root/synapse/util/async.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async.py')
-rw-r--r--synapse/util/async.py66
1 files changed, 50 insertions, 16 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..1c2044e5b4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
 
 
-@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
-    with PreserveLoggingContext():
-        yield d
+    return preserve_context_over_deferred(d)
 
 
 def run_on_reactor():
@@ -34,20 +32,56 @@ def run_on_reactor():
     return sleep(0)
 
 
-def create_observer(deferred):
-    """Creates a deferred that observes the result or failure of the given
-     deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+    """Wraps a deferred object so that we can add observer deferreds. These
+    observer deferreds do not affect the callback chain of the original
+    deferred.
+
+    If consumeErrors is true errors will be captured from the origin deferred.
     """
-    d = defer.Deferred()
 
-    def callback(r):
-        d.callback(r)
-        return r
+    __slots__ = ["_deferred", "_observers", "_result"]
+
+    def __init__(self, deferred, consumeErrors=False):
+        object.__setattr__(self, "_deferred", deferred)
+        object.__setattr__(self, "_result", None)
+        object.__setattr__(self, "_observers", [])
+
+        def callback(r):
+            self._result = (True, r)
+            while self._observers:
+                try:
+                    self._observers.pop().callback(r)
+                except:
+                    pass
+            return r
+
+        def errback(f):
+            self._result = (False, f)
+            while self._observers:
+                try:
+                    self._observers.pop().errback(f)
+                except:
+                    pass
+
+            if consumeErrors:
+                return None
+            else:
+                return f
+
+        deferred.addCallbacks(callback, errback)
 
-    def errback(f):
-        d.errback(f)
-        return f
+    def observe(self):
+        if not self._result:
+            d = defer.Deferred()
+            self._observers.append(d)
+            return d
+        else:
+            success, res = self._result
+            return defer.succeed(res) if success else defer.fail(res)
 
-    deferred.addCallbacks(callback, errback)
+    def __getattr__(self, name):
+        return getattr(self._deferred, name)
 
-    return d
+    def __setattr__(self, name, value):
+        setattr(self._deferred, name, value)