summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15650.misc1
-rw-r--r--synapse/logging/opentracing.py37
-rw-r--r--tests/logging/test_opentracing.py43
3 files changed, 59 insertions, 22 deletions
diff --git a/changelog.d/15650.misc b/changelog.d/15650.misc
new file mode 100644
index 0000000000..9bbad113e1
--- /dev/null
+++ b/changelog.d/15650.misc
@@ -0,0 +1 @@
+Add support for tracing functions which return `Awaitable`s.
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c70eee649c..75217e3f45 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -171,6 +171,7 @@ from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     ContextManager,
@@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
     """
 
     if inspect.iscoroutinefunction(func):
+        # For this branch, we handle async functions like `async def func() -> RInner`.
         # In this branch, R = Awaitable[RInner], for some other type RInner
         @wraps(func)
         async def _wrapper(
@@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
                 return await func(*args, **kwargs)  # type: ignore[misc]
 
     else:
-        # The other case here handles both sync functions and those
-        # decorated with inlineDeferred.
+        # The other case here handles sync functions including those decorated with
+        # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
         @wraps(func)
-        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
             scope = wrapping_logic(func, *args, **kwargs)
             scope.__enter__()
 
             try:
                 result = func(*args, **kwargs)
+
                 if isinstance(result, defer.Deferred):
 
                     def call_back(result: R) -> R:
@@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
                         return result
 
                     def err_back(result: R) -> R:
+                        # TODO: Pass the error details into `scope.__exit__(...)` for
+                        #       consistency with the other paths.
                         scope.__exit__(None, None, None)
                         return result
 
                     result.addCallbacks(call_back, err_back)
 
+                elif inspect.isawaitable(result):
+
+                    async def wrap_awaitable() -> Any:
+                        try:
+                            assert isinstance(result, Awaitable)
+                            awaited_result = await result
+                            scope.__exit__(None, None, None)
+                            return awaited_result
+                        except Exception as e:
+                            scope.__exit__(type(e), None, e.__traceback__)
+                            raise
+
+                    # The original method returned an awaitable, eg. a coroutine, so we
+                    # create another awaitable wrapping it that calls
+                    # `scope.__exit__(...)`.
+                    return wrap_awaitable()
                 else:
-                    if inspect.isawaitable(result):
-                        logger.error(
-                            "@trace may not have wrapped %s correctly! "
-                            "The function is not async but returned a %s.",
-                            func.__qualname__,
-                            type(result).__name__,
-                        )
-
+                    # Just a simple sync function so we can just exit the scope and
+                    # return the result without any fuss.
                     scope.__exit__(None, None, None)
 
                 return result
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e28ba84cc2..1bc7d64ad9 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import cast
+from typing import Awaitable, cast
 
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactorClock
@@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
         Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
         with functions that return deferreds
         """
-        reactor = MemoryReactorClock()
-
         with LoggingContext("root context"):
 
             @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
 
             result_d1 = fixture_deferred_func()
 
-            # let the tasks complete
-            reactor.pump((2,) * 8)
-
             self.assertEqual(self.successResultOf(result_d1), "foo")
 
         # the span should have been reported
@@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
         Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
         with async functions
         """
-        reactor = MemoryReactorClock()
-
         with LoggingContext("root context"):
 
             @trace_with_opname("fixture_async_func", tracer=self._tracer)
@@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
 
             d1 = defer.ensureDeferred(fixture_async_func())
 
-            # let the tasks complete
-            reactor.pump((2,) * 8)
-
             self.assertEqual(self.successResultOf(d1), "foo")
 
         # the span should have been reported
@@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
             [span.operation_name for span in self._reporter.get_spans()],
             ["fixture_async_func"],
         )
+
+    def test_trace_decorator_awaitable_return(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with functions that return an awaitable (e.g. a coroutine)
+        """
+        with LoggingContext("root context"):
+            # Something we can return without `await` to get a coroutine
+            async def fixture_async_func() -> str:
+                return "foo"
+
+            # The actual kind of function we want to test that returns an awaitable
+            @trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
+            @tag_args
+            def fixture_awaitable_return_func() -> Awaitable[str]:
+                return fixture_async_func()
+
+            # Something we can run with `defer.ensureDeferred(runner())` and pump the
+            # whole async tasks through to completion.
+            async def runner() -> str:
+                return await fixture_awaitable_return_func()
+
+            d1 = defer.ensureDeferred(runner())
+
+            self.assertEqual(self.successResultOf(d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_awaitable_return_func"],
+        )