summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/distributor.py28
1 files changed, 26 insertions, 2 deletions
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index da20523b70..22a857a306 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,10 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import inspect
 import logging
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred, fail, succeed
+from twisted.python import failure
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -79,6 +81,28 @@ class Distributor(object):
         run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
 
 
+def maybeAwaitableDeferred(f, *args, **kw):
+    """
+    Invoke a function that may or may not return a Deferred or an Awaitable.
+
+    This is a modified version of twisted.internet.defer.maybeDeferred.
+    """
+    try:
+        result = f(*args, **kw)
+    except Exception:
+        return fail(failure.Failure(captureVars=Deferred.debug))
+
+    if isinstance(result, Deferred):
+        return result
+    # Handle the additional case of an awaitable being returned.
+    elif inspect.isawaitable(result):
+        return defer.ensureDeferred(result)
+    elif isinstance(result, failure.Failure):
+        return fail(result)
+    else:
+        return succeed(result)
+
+
 class Signal(object):
     """A Signal is a dispatch point that stores a list of callables as
     observers of it.
@@ -122,7 +146,7 @@ class Signal(object):
                     ),
                 )
 
-            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+            return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
 
         deferreds = [run_in_background(do, o) for o in self.observers]