summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py173
1 files changed, 104 insertions, 69 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5dddf57008..7df0aa197d 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,28 +164,28 @@ Gotchas
   than one caller? Will all of those calling functions have be in a context
   with an active span?
 """
-
 import contextlib
 import inspect
 import logging
 import re
-import types
 from functools import wraps
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
-from canonicaljson import json
+import attr
 
 from twisted.internet import defer
 
 from synapse.config import ConfigError
+from synapse.util import json_decoder, json_encoder
 
 if TYPE_CHECKING:
+    from synapse.http.site import SynapseRequest
     from synapse.server import HomeServer
 
 # Helper class
 
 
-class _DummyTagNames(object):
+class _DummyTagNames:
     """wrapper of opentracings tags. We need to have them if we
     want to reference them without opentracing around. Clearly they
     should never actually show up in a trace. `set_tags` overwrites
@@ -226,12 +226,37 @@ except ImportError:
     tags = _DummyTagNames
 try:
     from jaeger_client import Config as JaegerConfig
+
     from synapse.logging.scopecontextmanager import LogContextScopeManager
 except ImportError:
     JaegerConfig = None  # type: ignore
     LogContextScopeManager = None  # type: ignore
 
 
+try:
+    from rust_python_jaeger_reporter import Reporter
+
+    @attr.s(slots=True, frozen=True)
+    class _WrappedRustReporter:
+        """Wrap the reporter to ensure `report_span` never throws.
+        """
+
+        _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
+
+        def set_process(self, *args, **kwargs):
+            return self._reporter.set_process(*args, **kwargs)
+
+        def report_span(self, span):
+            try:
+                return self._reporter.report_span(span)
+            except Exception:
+                logger.exception("Failed to report span")
+
+    RustReporter = _WrappedRustReporter  # type: Optional[Type[_WrappedRustReporter]]
+except ImportError:
+    RustReporter = None
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -320,11 +345,19 @@ def init_tracer(hs: "HomeServer"):
 
     set_homeserver_whitelist(hs.config.opentracer_whitelist)
 
-    JaegerConfig(
+    config = JaegerConfig(
         config=hs.config.jaeger_config,
         service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
         scope_manager=LogContextScopeManager(hs.config),
-    ).initialize_tracer()
+    )
+
+    # If we have the rust jaeger reporter available let's use that.
+    if RustReporter:
+        logger.info("Using rust_python_jaeger_reporter library")
+        tracer = config.create_tracer(RustReporter(), config.sampler)
+        opentracing.set_global_tracer(tracer)
+    else:
+        config.initialize_tracer()
 
 
 # Whitelisting
@@ -466,7 +499,9 @@ def start_active_span_from_edu(
     if opentracing is None:
         return _noop_context_manager()
 
-    carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+    carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+        "opentracing", {}
+    )
     context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
     _references = [
         opentracing.child_of(span_context_from_string(x))
@@ -657,7 +692,7 @@ def active_span_context_as_string():
         opentracing.tracer.inject(
             opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
         )
-    return json.dumps(carrier)
+    return json_encoder.encode(carrier)
 
 
 @only_if_tracing
@@ -666,7 +701,7 @@ def span_context_from_string(carrier):
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json.loads(carrier)
+    carrier = json_decoder.decode(carrier)
     return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
 
 
@@ -700,37 +735,43 @@ def trace(func=None, opname=None):
 
         _opname = opname if opname else func.__name__
 
-        @wraps(func)
-        def _trace_inner(*args, **kwargs):
-            if opentracing is None:
-                return func(*args, **kwargs)
+        if inspect.iscoroutinefunction(func):
 
-            scope = start_active_span(_opname)
-            scope.__enter__()
+            @wraps(func)
+            async def _trace_inner(*args, **kwargs):
+                with start_active_span(_opname):
+                    return await func(*args, **kwargs)
 
-            try:
-                result = func(*args, **kwargs)
-                if isinstance(result, defer.Deferred):
+        else:
+            # The other case here handles both sync functions and those
+            # decorated with inlineDeferred.
+            @wraps(func)
+            def _trace_inner(*args, **kwargs):
+                scope = start_active_span(_opname)
+                scope.__enter__()
 
-                    def call_back(result):
-                        scope.__exit__(None, None, None)
-                        return result
+                try:
+                    result = func(*args, **kwargs)
+                    if isinstance(result, defer.Deferred):
 
-                    def err_back(result):
-                        scope.span.set_tag(tags.ERROR, True)
-                        scope.__exit__(None, None, None)
-                        return result
+                        def call_back(result):
+                            scope.__exit__(None, None, None)
+                            return result
 
-                    result.addCallbacks(call_back, err_back)
+                        def err_back(result):
+                            scope.__exit__(None, None, None)
+                            return result
 
-                else:
-                    scope.__exit__(None, None, None)
+                        result.addCallbacks(call_back, err_back)
 
-                return result
+                    else:
+                        scope.__exit__(None, None, None)
 
-            except Exception as e:
-                scope.__exit__(type(e), None, e.__traceback__)
-                raise
+                    return result
+
+                except Exception as e:
+                    scope.__exit__(type(e), None, e.__traceback__)
+                    raise
 
         return _trace_inner
 
@@ -760,48 +801,42 @@ def tag_args(func):
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, extract_context=False):
-    """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+    """Returns a context manager which traces a request. It starts a span
+    with some servlet specific tags such as the request metrics name and
+    request information.
 
     Args:
-        servlet_name (str): The name to be used for the span's operation_name
-        extract_context (bool): Whether to attempt to extract the opentracing
+        request
+        extract_context: Whether to attempt to extract the opentracing
             context from the request the servlet is handling.
-
     """
 
-    def _trace_servlet_inner_1(func):
-        if not opentracing:
-            return func
-
-        @wraps(func)
-        async def _trace_servlet_inner(request, *args, **kwargs):
-            request_tags = {
-                "request_id": request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-            }
-
-            if extract_context:
-                scope = start_active_span_from_request(
-                    request, servlet_name, tags=request_tags
-                )
-            else:
-                scope = start_active_span(servlet_name, tags=request_tags)
-
-            with scope:
-                result = func(request, *args, **kwargs)
+    if opentracing is None:
+        yield
+        return
 
-                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
-                    # Some servlets aren't async and just return results
-                    # directly, so we handle that here.
-                    return result
+    request_tags = {
+        "request_id": request.get_request_id(),
+        tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+        tags.HTTP_METHOD: request.get_method(),
+        tags.HTTP_URL: request.get_redacted_uri(),
+        tags.PEER_HOST_IPV6: request.getClientIP(),
+    }
 
-                return await result
+    request_name = request.request_metrics.name
+    if extract_context:
+        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+    else:
+        scope = start_active_span(request_name, tags=request_tags)
 
-        return _trace_servlet_inner
+    with scope:
+        try:
+            yield
+        finally:
+            # We set the operation name again in case its changed (which happens
+            # with JsonResource).
+            scope.span.set_operation_name(request.request_metrics.name)
 
-    return _trace_servlet_inner_1
+            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)