diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/__init__.py | 49 | ||||
-rw-r--r-- | synapse/util/async.py | 66 | ||||
-rw-r--r-- | synapse/util/distributor.py | 53 | ||||
-rw-r--r-- | synapse/util/logcontext.py | 52 |
4 files changed, 175 insertions, 45 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 79109d0b19..260714ccc2 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -23,6 +23,40 @@ import logging logger = logging.getLogger(__name__) +def unwrapFirstError(failure): + # defer.gatherResults and DeferredLists wrap failures. + failure.trap(defer.FirstError) + return failure.value.subFailure + + +def unwrap_deferred(d): + """Given a deferred that we know has completed, return its value or raise + the failure as an exception + """ + if not d.called: + raise RuntimeError("deferred has not finished") + + res = [] + + def f(r): + res.append(r) + return r + d.addCallback(f) + + if res: + return res[0] + + def f(r): + res.append(r) + return r + d.addErrback(f) + + if res: + res[0].raiseException() + else: + raise RuntimeError("deferred did not call callbacks") + + class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. @@ -46,13 +80,16 @@ class Clock(object): def stop_looping_call(self, loop): loop.stop() - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() - def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context - callback() - return reactor.callLater(delay, wrapped_callback) + def wrapped_callback(*args, **kwargs): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + callback(*args, **kwargs) + + with PreserveLoggingContext(): + return reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer): timer.cancel() diff --git a/synapse/util/async.py b/synapse/util/async.py index d8febdb90c..1c2044e5b4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,15 +16,13 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import preserve_context_over_deferred -@defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) - with PreserveLoggingContext(): - yield d + return preserve_context_over_deferred(d) def run_on_reactor(): @@ -34,20 +32,56 @@ def run_on_reactor(): return sleep(0) -def create_observer(deferred): - """Creates a deferred that observes the result or failure of the given - deferred *without* affecting the given deferred. +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. """ - d = defer.Deferred() - def callback(r): - d.callback(r) - return r + __slots__ = ["_deferred", "_observers", "_result"] + + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", []) + + def callback(r): + self._result = (True, r) + while self._observers: + try: + self._observers.pop().callback(r) + except: + pass + return r + + def errback(f): + self._result = (False, f) + while self._observers: + try: + self._observers.pop().errback(f) + except: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) - def errback(f): - d.errback(f) - return f + def observe(self): + if not self._result: + d = defer.Deferred() + self._observers.append(d) + return d + else: + success, res = self._result + return defer.succeed(res) if success else defer.fail(res) - deferred.addCallbacks(callback, errback) + def __getattr__(self, name): + return getattr(self._deferred, name) - return d + def __setattr__(self, name, value): + setattr(self._deferred, name, value) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 9d9c350397..064c4a7a1e 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import PreserveLoggingContext - from twisted.internet import defer +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_deferred, +) + +from synapse.util import unwrapFirstError + import logging @@ -93,7 +97,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -101,24 +104,28 @@ class Signal(object): Returns a Deferred that will complete when all the observers have completed.""" + + def do(observer): + def eb(failure): + logger.warning( + "%s signal observer %s failed: %r", + self.name, observer, failure, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject())) + if not self.suppress_failures: + return failure + return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + with PreserveLoggingContext(): - deferreds = [] - for observer in self.observers: - d = defer.maybeDeferred(observer, *args, **kwargs) - - def eb(failure): - logger.warning( - "%s signal observer %s failed: %r", - self.name, observer, failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject())) - if not self.suppress_failures: - failure.raiseException() - deferreds.append(d.addErrback(eb)) - results = [] - for deferred in deferreds: - result = yield deferred - results.append(result) - defer.returnValue(results) + deferreds = [ + do(observer) + for observer in self.observers + ] + + d = defer.gatherResults(deferreds, consumeErrors=True) + + d.addErrback(unwrapFirstError) + + return preserve_context_over_deferred(d) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index da7872e95d..a92d518b43 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + import threading import logging @@ -129,3 +131,53 @@ class PreserveLoggingContext(object): def __exit__(self, type, value, traceback): """Restores the current logging context""" LoggingContext.thread_local.current_context = self.current_context + + if self.current_context is not LoggingContext.sentinel: + if self.current_context.parent_context is None: + logger.warn( + "Restoring dead context: %s", + self.current_context, + ) + + +def preserve_context_over_fn(fn, *args, **kwargs): + """Takes a function and invokes it with the given arguments, but removes + and restores the current logging context while doing so. + + If the result is a deferred, call preserve_context_over_deferred before + returning it. + """ + with PreserveLoggingContext(): + res = fn(*args, **kwargs) + + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred(res) + else: + return res + + +def preserve_context_over_deferred(deferred): + """Given a deferred wrap it such that any callbacks added later to it will + be invoked with the current context. + """ + d = defer.Deferred() + + current_context = LoggingContext.current_context() + + def cb(res): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.callback(res) + return res + + def eb(failure): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.errback(failure) + return res + + if deferred.called: + return deferred + + deferred.addCallbacks(cb, eb) + return d |