summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py17
-rw-r--r--synapse/util/async.py66
-rw-r--r--synapse/util/distributor.py53
-rw-r--r--synapse/util/logcontext.py52
4 files changed, 145 insertions, 43 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..c1a16b639a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
 
@@ -23,6 +23,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+def unwrapFirstError(failure):
+    # defer.gatherResults and DeferredLists wrap failures.
+    failure.trap(defer.FirstError)
+    return failure.value.subFailure
+
+
 class Clock(object):
     """A small utility that obtains current time-of-day so that time may be
     mocked during unit-tests.
@@ -50,9 +56,12 @@ class Clock(object):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback():
-            LoggingContext.thread_local.current_context = current_context
-            callback()
-        return reactor.callLater(delay, wrapped_callback)
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = current_context
+                callback()
+
+        with PreserveLoggingContext():
+            return reactor.callLater(delay, wrapped_callback)
 
     def cancel_call_later(self, timer):
         timer.cancel()
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..1c2044e5b4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
 
 
-@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
-    with PreserveLoggingContext():
-        yield d
+    return preserve_context_over_deferred(d)
 
 
 def run_on_reactor():
@@ -34,20 +32,56 @@ def run_on_reactor():
     return sleep(0)
 
 
-def create_observer(deferred):
-    """Creates a deferred that observes the result or failure of the given
-     deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+    """Wraps a deferred object so that we can add observer deferreds. These
+    observer deferreds do not affect the callback chain of the original
+    deferred.
+
+    If consumeErrors is true errors will be captured from the origin deferred.
     """
-    d = defer.Deferred()
 
-    def callback(r):
-        d.callback(r)
-        return r
+    __slots__ = ["_deferred", "_observers", "_result"]
+
+    def __init__(self, deferred, consumeErrors=False):
+        object.__setattr__(self, "_deferred", deferred)
+        object.__setattr__(self, "_result", None)
+        object.__setattr__(self, "_observers", [])
+
+        def callback(r):
+            self._result = (True, r)
+            while self._observers:
+                try:
+                    self._observers.pop().callback(r)
+                except:
+                    pass
+            return r
+
+        def errback(f):
+            self._result = (False, f)
+            while self._observers:
+                try:
+                    self._observers.pop().errback(f)
+                except:
+                    pass
+
+            if consumeErrors:
+                return None
+            else:
+                return f
+
+        deferred.addCallbacks(callback, errback)
 
-    def errback(f):
-        d.errback(f)
-        return f
+    def observe(self):
+        if not self._result:
+            d = defer.Deferred()
+            self._observers.append(d)
+            return d
+        else:
+            success, res = self._result
+            return defer.succeed(res) if success else defer.fail(res)
 
-    deferred.addCallbacks(callback, errback)
+    def __getattr__(self, name):
+        return getattr(self._deferred, name)
 
-    return d
+    def __setattr__(self, name, value):
+        setattr(self._deferred, name, value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9d9c350397..064c4a7a1e 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import PreserveLoggingContext
-
 from twisted.internet import defer
 
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred,
+)
+
+from synapse.util import unwrapFirstError
+
 import logging
 
 
@@ -93,7 +97,6 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -101,24 +104,28 @@ class Signal(object):
 
         Returns a Deferred that will complete when all the observers have
         completed."""
+
+        def do(observer):
+            def eb(failure):
+                logger.warning(
+                    "%s signal observer %s failed: %r",
+                    self.name, observer, failure,
+                    exc_info=(
+                        failure.type,
+                        failure.value,
+                        failure.getTracebackObject()))
+                if not self.suppress_failures:
+                    return failure
+            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+
         with PreserveLoggingContext():
-            deferreds = []
-            for observer in self.observers:
-                d = defer.maybeDeferred(observer, *args, **kwargs)
-
-                def eb(failure):
-                    logger.warning(
-                        "%s signal observer %s failed: %r",
-                        self.name, observer, failure,
-                        exc_info=(
-                            failure.type,
-                            failure.value,
-                            failure.getTracebackObject()))
-                    if not self.suppress_failures:
-                        failure.raiseException()
-                deferreds.append(d.addErrback(eb))
-            results = []
-            for deferred in deferreds:
-                result = yield deferred
-                results.append(result)
-        defer.returnValue(results)
+            deferreds = [
+                do(observer)
+                for observer in self.observers
+            ]
+
+            d = defer.gatherResults(deferreds, consumeErrors=True)
+
+        d.addErrback(unwrapFirstError)
+
+        return preserve_context_over_deferred(d)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..a92d518b43 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 import threading
 import logging
 
@@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         LoggingContext.thread_local.current_context = self.current_context
+
+        if self.current_context is not LoggingContext.sentinel:
+            if self.current_context.parent_context is None:
+                logger.warn(
+                    "Restoring dead context: %s",
+                    self.current_context,
+                )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+    """Takes a function and invokes it with the given arguments, but removes
+    and restores the current logging context while doing so.
+
+    If the result is a deferred, call preserve_context_over_deferred before
+    returning it.
+    """
+    with PreserveLoggingContext():
+        res = fn(*args, **kwargs)
+
+    if isinstance(res, defer.Deferred):
+        return preserve_context_over_deferred(res)
+    else:
+        return res
+
+
+def preserve_context_over_deferred(deferred):
+    """Given a deferred wrap it such that any callbacks added later to it will
+    be invoked with the current context.
+    """
+    d = defer.Deferred()
+
+    current_context = LoggingContext.current_context()
+
+    def cb(res):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.callback(res)
+        return res
+
+    def eb(failure):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.errback(failure)
+        return res
+
+    if deferred.called:
+        return deferred
+
+    deferred.addCallbacks(cb, eb)
+    return d