summary refs log tree commit diff
path: root/synapse/util/distributor.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/distributor.py')
-rw-r--r--synapse/util/distributor.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 064c4a7a1e..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,9 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
 
 from synapse.util import unwrapFirstError
 
@@ -97,6 +95,7 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
+    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
                         failure.getTracebackObject()))
                 if not self.suppress_failures:
                     return failure
+
             return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
 
         with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
                 for observer in self.observers
             ]
 
-            d = defer.gatherResults(deferreds, consumeErrors=True)
+            res = yield defer.gatherResults(
+                deferreds, consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
-        d.addErrback(unwrapFirstError)
+        defer.returnValue(res)
 
-        return preserve_context_over_deferred(d)
+    def __repr__(self):
+        return "<Signal name=%r>" % (self.name,)