summary refs log tree commit diff
path: root/tests/test_distributor.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_distributor.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 39c5b8dff2..6a0095d850 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tests import unittest
+from . import unittest
 from twisted.internet import defer
 
 from mock import Mock, patch
 
 from synapse.util.distributor import Distributor
+from synapse.util.async import run_on_reactor
 
 
 class DistributorTestCase(unittest.TestCase):
@@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
     def setUp(self):
         self.dist = Distributor()
 
+    @defer.inlineCallbacks
     def test_signal_dispatch(self):
         self.dist.declare("alert")
 
@@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
         self.dist.observe("alert", observer)
 
         d = self.dist.fire("alert", 1, 2, 3)
-
+        yield d
         self.assertTrue(d.called)
         observer.assert_called_with(1, 2, 3)
 
+    @defer.inlineCallbacks
     def test_signal_dispatch_deferred(self):
         self.dist.declare("whine")
 
@@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
         self.assertFalse(d_outer.called)
 
         d_inner.callback(None)
+        yield d_outer
         self.assertTrue(d_outer.called)
 
+    @defer.inlineCallbacks
     def test_signal_catch(self):
         self.dist.declare("alarm")
 
@@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
                 spec=["warning"]
         ) as mock_logger:
             d = self.dist.fire("alarm", "Go")
+            yield d
             self.assertTrue(d.called)
 
             observers[0].assert_called_once("Go")
@@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
 
         self.dist.declare("whail")
 
-        observer = Mock()
-        observer.return_value = defer.fail(
-            Exception("Oopsie")
-        )
+        class MyException(Exception):
+            pass
+
+        @defer.inlineCallbacks
+        def observer():
+            yield run_on_reactor()
+            raise MyException("Oopsie")
 
         self.dist.observe("whail", observer)
 
         d = self.dist.fire("whail")
 
-        yield self.assertFailure(d, Exception)
+        yield self.assertFailure(d, MyException)
+        self.dist.suppress_failures = True
 
+    @defer.inlineCallbacks
     def test_signal_prereg(self):
         observer = Mock()
         self.dist.observe("flare", observer)
 
         self.dist.declare("flare")
-        self.dist.fire("flare", 4, 5)
+        yield self.dist.fire("flare", 4, 5)
 
         observer.assert_called_with(4, 5)