summary refs log tree commit diff
path: root/tests/logging/test_opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/logging/test_opentracing.py')
-rw-r--r--tests/logging/test_opentracing.py113
1 files changed, 104 insertions, 9 deletions
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 40148d503c..0917e478a5 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
+
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactorClock
 
@@ -23,6 +25,8 @@ from synapse.logging.context import (
 from synapse.logging.opentracing import (
     start_active_span,
     start_active_span_follows_from,
+    tag_args,
+    trace_with_opname,
 )
 from synapse.util import Clock
 
@@ -36,10 +40,23 @@ try:
 except ImportError:
     jaeger_client = None  # type: ignore
 
+import logging
+
 from tests.unittest import TestCase
 
+logger = logging.getLogger(__name__)
+
 
 class LogContextScopeManagerTestCase(TestCase):
+    """
+    Test logging contexts and active opentracing spans.
+
+    There's casts throughout this from generic opentracing objects (e.g.
+    opentracing.Span) to the ones specific to Jaeger since they have additional
+    properties that these tests depend on. This is safe since the only supported
+    opentracing backend is Jaeger.
+    """
+
     if LogContextScopeManager is None:
         skip = "Requires opentracing"  # type: ignore[unreachable]
     if jaeger_client is None:
@@ -69,7 +86,7 @@ class LogContextScopeManagerTestCase(TestCase):
 
             # start_active_span should start and activate a span.
             scope = start_active_span("span", tracer=self._tracer)
-            span = scope.span
+            span = cast(jaeger_client.Span, scope.span)
             self.assertEqual(self._tracer.active_span, span)
             self.assertIsNotNone(span.start_time)
 
@@ -91,6 +108,7 @@ class LogContextScopeManagerTestCase(TestCase):
         with LoggingContext("root context"):
             with start_active_span("root span", tracer=self._tracer) as root_scope:
                 self.assertEqual(self._tracer.active_span, root_scope.span)
+                root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
 
                 scope1 = start_active_span(
                     "child1",
@@ -99,9 +117,8 @@ class LogContextScopeManagerTestCase(TestCase):
                 self.assertEqual(
                     self._tracer.active_span, scope1.span, "child1 was not activated"
                 )
-                self.assertEqual(
-                    scope1.span.context.parent_id, root_scope.span.context.span_id
-                )
+                context1 = cast(jaeger_client.SpanContext, scope1.span.context)
+                self.assertEqual(context1.parent_id, root_context.span_id)
 
                 scope2 = start_active_span_follows_from(
                     "child2",
@@ -109,17 +126,18 @@ class LogContextScopeManagerTestCase(TestCase):
                     tracer=self._tracer,
                 )
                 self.assertEqual(self._tracer.active_span, scope2.span)
-                self.assertEqual(
-                    scope2.span.context.parent_id, scope1.span.context.span_id
-                )
+                context2 = cast(jaeger_client.SpanContext, scope2.span.context)
+                self.assertEqual(context2.parent_id, context1.span_id)
 
                 with scope1, scope2:
                     pass
 
                 # the root scope should be restored
                 self.assertEqual(self._tracer.active_span, root_scope.span)
-                self.assertIsNotNone(scope2.span.end_time)
-                self.assertIsNotNone(scope1.span.end_time)
+                span2 = cast(jaeger_client.Span, scope2.span)
+                span1 = cast(jaeger_client.Span, scope1.span)
+                self.assertIsNotNone(span2.end_time)
+                self.assertIsNotNone(span1.end_time)
 
             self.assertIsNone(self._tracer.active_span)
 
@@ -182,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
             self._reporter.get_spans(),
             [scopes[1].span, scopes[2].span, scopes[0].span],
         )
+
+    def test_trace_decorator_sync(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with sync functions
+        """
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_sync_func", tracer=self._tracer)
+            @tag_args
+            def fixture_sync_func() -> str:
+                return "foo"
+
+            result = fixture_sync_func()
+            self.assertEqual(result, "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_sync_func"],
+        )
+
+    def test_trace_decorator_deferred(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with functions that return deferreds
+        """
+        reactor = MemoryReactorClock()
+
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
+            @tag_args
+            def fixture_deferred_func() -> "defer.Deferred[str]":
+                d1: defer.Deferred[str] = defer.Deferred()
+                d1.callback("foo")
+                return d1
+
+            result_d1 = fixture_deferred_func()
+
+            # let the tasks complete
+            reactor.pump((2,) * 8)
+
+            self.assertEqual(self.successResultOf(result_d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_deferred_func"],
+        )
+
+    def test_trace_decorator_async(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with async functions
+        """
+        reactor = MemoryReactorClock()
+
+        with LoggingContext("root context"):
+
+            @trace_with_opname("fixture_async_func", tracer=self._tracer)
+            @tag_args
+            async def fixture_async_func() -> str:
+                return "foo"
+
+            d1 = defer.ensureDeferred(fixture_async_func())
+
+            # let the tasks complete
+            reactor.pump((2,) * 8)
+
+            self.assertEqual(self.successResultOf(d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_async_func"],
+        )