diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..c1a16b639a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -23,6 +23,12 @@ import logging
logger = logging.getLogger(__name__)
+def unwrapFirstError(failure):
+ # defer.gatherResults and DeferredLists wrap failures.
+ failure.trap(defer.FirstError)
+ return failure.value.subFailure
+
+
class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.
@@ -50,9 +56,12 @@ class Clock(object):
current_context = LoggingContext.current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
- callback()
- return reactor.callLater(delay, wrapped_callback)
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ callback()
+
+ with PreserveLoggingContext():
+ return reactor.callLater(delay, wrapped_callback)
def cancel_call_later(self, timer):
timer.cancel()
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..1c2044e5b4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
-@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
- with PreserveLoggingContext():
- yield d
+ return preserve_context_over_deferred(d)
def run_on_reactor():
@@ -34,20 +32,56 @@ def run_on_reactor():
return sleep(0)
-def create_observer(deferred):
- """Creates a deferred that observes the result or failure of the given
- deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+ """Wraps a deferred object so that we can add observer deferreds. These
+ observer deferreds do not affect the callback chain of the original
+ deferred.
+
+ If consumeErrors is true errors will be captured from the origin deferred.
"""
- d = defer.Deferred()
- def callback(r):
- d.callback(r)
- return r
+ __slots__ = ["_deferred", "_observers", "_result"]
+
+ def __init__(self, deferred, consumeErrors=False):
+ object.__setattr__(self, "_deferred", deferred)
+ object.__setattr__(self, "_result", None)
+ object.__setattr__(self, "_observers", [])
+
+ def callback(r):
+ self._result = (True, r)
+ while self._observers:
+ try:
+ self._observers.pop().callback(r)
+ except:
+ pass
+ return r
+
+ def errback(f):
+ self._result = (False, f)
+ while self._observers:
+ try:
+ self._observers.pop().errback(f)
+ except:
+ pass
+
+ if consumeErrors:
+ return None
+ else:
+ return f
+
+ deferred.addCallbacks(callback, errback)
- def errback(f):
- d.errback(f)
- return f
+ def observe(self):
+ if not self._result:
+ d = defer.Deferred()
+ self._observers.append(d)
+ return d
+ else:
+ success, res = self._result
+ return defer.succeed(res) if success else defer.fail(res)
- deferred.addCallbacks(callback, errback)
+ def __getattr__(self, name):
+ return getattr(self._deferred, name)
- return d
+ def __setattr__(self, name, value):
+ setattr(self._deferred, name, value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9d9c350397..064c4a7a1e 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import PreserveLoggingContext
-
from twisted.internet import defer
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred,
+)
+
+from synapse.util import unwrapFirstError
+
import logging
@@ -93,7 +97,6 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
- @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -101,24 +104,28 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have
completed."""
+
+ def do(observer):
+ def eb(failure):
+ logger.warning(
+ "%s signal observer %s failed: %r",
+ self.name, observer, failure,
+ exc_info=(
+ failure.type,
+ failure.value,
+ failure.getTracebackObject()))
+ if not self.suppress_failures:
+ return failure
+ return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+
with PreserveLoggingContext():
- deferreds = []
- for observer in self.observers:
- d = defer.maybeDeferred(observer, *args, **kwargs)
-
- def eb(failure):
- logger.warning(
- "%s signal observer %s failed: %r",
- self.name, observer, failure,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject()))
- if not self.suppress_failures:
- failure.raiseException()
- deferreds.append(d.addErrback(eb))
- results = []
- for deferred in deferreds:
- result = yield deferred
- results.append(result)
- defer.returnValue(results)
+ deferreds = [
+ do(observer)
+ for observer in self.observers
+ ]
+
+ d = defer.gatherResults(deferreds, consumeErrors=True)
+
+ d.addErrback(unwrapFirstError)
+
+ return preserve_context_over_deferred(d)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
import threading
import logging
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context
+
+ if self.current_context is not LoggingContext.sentinel:
+ if self.current_context.parent_context is None:
+ logger.warn(
+ "Restoring dead context: %s",
+ self.current_context,
+ )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+ """Takes a function and invokes it with the given arguments, but removes
+ and restores the current logging context while doing so.
+
+ If the result is a deferred, call preserve_context_over_deferred before
+ returning it.
+ """
+ with PreserveLoggingContext():
+ res = fn(*args, **kwargs)
+
+ if isinstance(res, defer.Deferred):
+ return preserve_context_over_deferred(res)
+ else:
+ return res
+
+
+def preserve_context_over_deferred(deferred):
+ """Given a deferred wrap it such that any callbacks added later to it will
+ be invoked with the current context.
+ """
+ d = defer.Deferred()
+
+ current_context = LoggingContext.current_context()
+
+ def cb(res):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.callback(res)
+ return res
+
+ def eb(failure):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.errback(failure)
+ return res
+
+ if deferred.called:
+ return deferred
+
+ deferred.addCallbacks(cb, eb)
+ return d
|