summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/util/distributor.py17
-rw-r--r--tests/test_distributor.py24
2 files changed, 34 insertions, 7 deletions
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9605d7d1b9..9cffaec8f2 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -32,7 +32,9 @@ class Distributor(object):
       model will do for today.
     """
 
-    def __init__(self):
+    def __init__(self, suppress_failures=True):
+        self.suppress_failures = suppress_failures
+
         self.signals = {}
         self.pre_registration = {}
 
@@ -40,7 +42,9 @@ class Distributor(object):
         if name in self.signals:
             raise KeyError("%r already has a signal named %s" % (self, name))
 
-        self.signals[name] = Signal(name)
+        self.signals[name] = Signal(name,
+            suppress_failures=self.suppress_failures,
+        )
 
         if name in self.pre_registration:
             signal = self.signals[name]
@@ -74,8 +78,9 @@ class Signal(object):
     method into all of the observers.
     """
 
-    def __init__(self, name):
+    def __init__(self, name, suppress_failures):
         self.name = name
+        self.suppress_failures = suppress_failures
         self.observers = []
 
     def observe(self, observer):
@@ -104,6 +109,10 @@ class Signal(object):
                         failure.type,
                         failure.value,
                         failure.getTracebackObject()))
+                if not self.suppress_failures:
+                    raise failure
             deferreds.append(d.addErrback(eb))
 
-        return defer.DeferredList(deferreds)
+        return defer.DeferredList(
+            deferreds, fireOnOneErrback=not self.suppress_failures
+        )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 21c91f335b..2869fdfd76 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -13,9 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
-
 from twisted.internet import defer
+from twisted.trial import unittest
 
 from mock import Mock, patch
 
@@ -75,6 +74,24 @@ class DistributorTestCase(unittest.TestCase):
             self.assertIsInstance(mock_logger.warning.call_args[0][0],
                     str)
 
+    @defer.inlineCallbacks
+    def test_signal_catch_no_suppress(self):
+        # Gut-wrenching
+        self.dist.suppress_failures = False
+
+        self.dist.declare("whail")
+
+        observer = Mock()
+        observer.return_value = defer.fail(
+            Exception("Oopsie")
+        )
+
+        self.dist.observe("whail", observer)
+
+        d = self.dist.fire("whail")
+
+        yield self.assertFailure(d, Exception)
+
     def test_signal_prereg(self):
         observer = Mock()
         self.dist.observe("flare", observer)
@@ -85,5 +102,6 @@ class DistributorTestCase(unittest.TestCase):
         observer.assert_called_with(4, 5)
 
     def test_signal_undeclared(self):
-        with self.assertRaises(KeyError):
+        def code():
             self.dist.fire("notification")
+        self.assertRaises(KeyError, code)