summary refs log tree commit diff
path: root/synapse/util/distributor.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/distributor.py')
-rw-r--r--synapse/util/distributor.py53
1 files changed, 30 insertions, 23 deletions
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9d9c350397..064c4a7a1e 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import PreserveLoggingContext
-
 from twisted.internet import defer
 
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred,
+)
+
+from synapse.util import unwrapFirstError
+
 import logging
 
 
@@ -93,7 +97,6 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -101,24 +104,28 @@ class Signal(object):
 
         Returns a Deferred that will complete when all the observers have
         completed."""
+
+        def do(observer):
+            def eb(failure):
+                logger.warning(
+                    "%s signal observer %s failed: %r",
+                    self.name, observer, failure,
+                    exc_info=(
+                        failure.type,
+                        failure.value,
+                        failure.getTracebackObject()))
+                if not self.suppress_failures:
+                    return failure
+            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+
         with PreserveLoggingContext():
-            deferreds = []
-            for observer in self.observers:
-                d = defer.maybeDeferred(observer, *args, **kwargs)
-
-                def eb(failure):
-                    logger.warning(
-                        "%s signal observer %s failed: %r",
-                        self.name, observer, failure,
-                        exc_info=(
-                            failure.type,
-                            failure.value,
-                            failure.getTracebackObject()))
-                    if not self.suppress_failures:
-                        failure.raiseException()
-                deferreds.append(d.addErrback(eb))
-            results = []
-            for deferred in deferreds:
-                result = yield deferred
-                results.append(result)
-        defer.returnValue(results)
+            deferreds = [
+                do(observer)
+                for observer in self.observers
+            ]
+
+            d = defer.gatherResults(deferreds, consumeErrors=True)
+
+        d.addErrback(unwrapFirstError)
+
+        return preserve_context_over_deferred(d)