summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2023-06-06 17:39:22 -0500
committerGitHub <noreply@github.com>2023-06-06 17:39:22 -0500
commit8bfded81f3378ab6333f174e182f2aae6ef01f49 (patch)
treea044fc90236de0cd7b9a216987c88509bf56be30 /synapse/logging/opentracing.py
parentUpdate error to more plainly explain we can only authorize our own events (#1... (diff)
downloadsynapse-8bfded81f3378ab6333f174e182f2aae6ef01f49.tar.xz
Trace functions which return `Awaitable` (#15650)
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c70eee649c..75217e3f45 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,6 +171,7 @@ from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     ContextManager,
@@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
     """
 
     if inspect.iscoroutinefunction(func):
+        # For this branch, we handle async functions like `async def func() -> RInner`.
         # In this branch, R = Awaitable[RInner], for some other type RInner
         @wraps(func)
         async def _wrapper(
@@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
                 return await func(*args, **kwargs)  # type: ignore[misc]
 
     else:
-        # The other case here handles both sync functions and those
-        # decorated with inlineDeferred.
+        # The other case here handles sync functions including those decorated with
+        # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
         @wraps(func)
-        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
             scope = wrapping_logic(func, *args, **kwargs)
             scope.__enter__()
 
             try:
                 result = func(*args, **kwargs)
+
                 if isinstance(result, defer.Deferred):
 
                     def call_back(result: R) -> R:
@@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
                         return result
 
                     def err_back(result: R) -> R:
+                        # TODO: Pass the error details into `scope.__exit__(...)` for
+                        #       consistency with the other paths.
                         scope.__exit__(None, None, None)
                         return result
 
                     result.addCallbacks(call_back, err_back)
 
+                elif inspect.isawaitable(result):
+
+                    async def wrap_awaitable() -> Any:
+                        try:
+                            assert isinstance(result, Awaitable)
+                            awaited_result = await result
+                            scope.__exit__(None, None, None)
+                            return awaited_result
+                        except Exception as e:
+                            scope.__exit__(type(e), None, e.__traceback__)
+                            raise
+
+                    # The original method returned an awaitable, eg. a coroutine, so we
+                    # create another awaitable wrapping it that calls
+                    # `scope.__exit__(...)`.
+                    return wrap_awaitable()
                 else:
-                    if inspect.isawaitable(result):
-                        logger.error(
-                            "@trace may not have wrapped %s correctly! "
-                            "The function is not async but returned a %s.",
-                            func.__qualname__,
-                            type(result).__name__,
-                        )
-
+                    # Just a simple sync function so we can just exit the scope and
+                    # return the result without any fuss.
                     scope.__exit__(None, None, None)
 
                 return result