diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 9605d7d1b9..9cffaec8f2 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -32,7 +32,9 @@ class Distributor(object):
model will do for today.
"""
- def __init__(self):
+ def __init__(self, suppress_failures=True):
+ self.suppress_failures = suppress_failures
+
self.signals = {}
self.pre_registration = {}
@@ -40,7 +42,9 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
- self.signals[name] = Signal(name)
+ self.signals[name] = Signal(name,
+ suppress_failures=self.suppress_failures,
+ )
if name in self.pre_registration:
signal = self.signals[name]
@@ -74,8 +78,9 @@ class Signal(object):
method into all of the observers.
"""
- def __init__(self, name):
+ def __init__(self, name, suppress_failures):
self.name = name
+ self.suppress_failures = suppress_failures
self.observers = []
def observe(self, observer):
@@ -104,6 +109,10 @@ class Signal(object):
failure.type,
failure.value,
failure.getTracebackObject()))
+ if not self.suppress_failures:
+ raise failure
deferreds.append(d.addErrback(eb))
- return defer.DeferredList(deferreds)
+ return defer.DeferredList(
+ deferreds, fireOnOneErrback=not self.suppress_failures
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 21c91f335b..2869fdfd76 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
-
from twisted.internet import defer
+from twisted.trial import unittest
from mock import Mock, patch
@@ -75,6 +74,24 @@ class DistributorTestCase(unittest.TestCase):
self.assertIsInstance(mock_logger.warning.call_args[0][0],
str)
+ @defer.inlineCallbacks
+ def test_signal_catch_no_suppress(self):
+ # Gut-wrenching
+ self.dist.suppress_failures = False
+
+ self.dist.declare("whail")
+
+ observer = Mock()
+ observer.return_value = defer.fail(
+ Exception("Oopsie")
+ )
+
+ self.dist.observe("whail", observer)
+
+ d = self.dist.fire("whail")
+
+ yield self.assertFailure(d, Exception)
+
def test_signal_prereg(self):
observer = Mock()
self.dist.observe("flare", observer)
@@ -85,5 +102,6 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
def test_signal_undeclared(self):
- with self.assertRaises(KeyError):
+ def code():
self.dist.fire("notification")
+ self.assertRaises(KeyError, code)
|