summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/distributor.py7
2 files changed, 7 insertions, 8 deletions
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
 # limitations under the License.
 
 import collections
+import inspect
 import logging
 from contextlib import contextmanager
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
         raise StopIteration(self.value)
 
 
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
     """Convert a value to an awaitable if not already an awaitable.
     """
-
-    if hasattr(value, "__await__"):
+    if inspect.isawaitable(value):
+        assert isinstance(value, Awaitable)
         return value
 
     return DoneAwaitable(value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -105,10 +105,7 @@ class Signal:
 
         async def do(observer):
             try:
-                result = observer(*args, **kwargs)
-                if inspect.isawaitable(result):
-                    result = await result
-                return result
+                return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
                 logger.warning(
                     "%s signal observer %s failed: %r", self.name, observer, e,