diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..1c2044e5b4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
-@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
- with PreserveLoggingContext():
- yield d
+ return preserve_context_over_deferred(d)
def run_on_reactor():
@@ -34,20 +32,56 @@ def run_on_reactor():
return sleep(0)
-def create_observer(deferred):
- """Creates a deferred that observes the result or failure of the given
- deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+ """Wraps a deferred object so that we can add observer deferreds. These
+ observer deferreds do not affect the callback chain of the original
+ deferred.
+
+ If consumeErrors is true errors will be captured from the origin deferred.
"""
- d = defer.Deferred()
- def callback(r):
- d.callback(r)
- return r
+ __slots__ = ["_deferred", "_observers", "_result"]
+
+ def __init__(self, deferred, consumeErrors=False):
+ object.__setattr__(self, "_deferred", deferred)
+ object.__setattr__(self, "_result", None)
+ object.__setattr__(self, "_observers", [])
+
+ def callback(r):
+ self._result = (True, r)
+ while self._observers:
+ try:
+ self._observers.pop().callback(r)
+ except:
+ pass
+ return r
+
+ def errback(f):
+ self._result = (False, f)
+ while self._observers:
+ try:
+ self._observers.pop().errback(f)
+ except:
+ pass
+
+ if consumeErrors:
+ return None
+ else:
+ return f
+
+ deferred.addCallbacks(callback, errback)
- def errback(f):
- d.errback(f)
- return f
+ def observe(self):
+ if not self._result:
+ d = defer.Deferred()
+ self._observers.append(d)
+ return d
+ else:
+ success, res = self._result
+ return defer.succeed(res) if success else defer.fail(res)
- deferred.addCallbacks(callback, errback)
+ def __getattr__(self, name):
+ return getattr(self._deferred, name)
- return d
+ def __setattr__(self, name, value):
+ setattr(self._deferred, name, value)
|