diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5dddf57008..7df0aa197d 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,28 +164,28 @@ Gotchas
than one caller? Will all of those calling functions have be in a context
with an active span?
"""
-
import contextlib
import inspect
import logging
import re
-import types
from functools import wraps
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Optional, Type
-from canonicaljson import json
+import attr
from twisted.internet import defer
from synapse.config import ConfigError
+from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
+ from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
# Helper class
-class _DummyTagNames(object):
+class _DummyTagNames:
"""wrapper of opentracings tags. We need to have them if we
want to reference them without opentracing around. Clearly they
should never actually show up in a trace. `set_tags` overwrites
@@ -226,12 +226,37 @@ except ImportError:
tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
+
from synapse.logging.scopecontextmanager import LogContextScopeManager
except ImportError:
JaegerConfig = None # type: ignore
LogContextScopeManager = None # type: ignore
+try:
+ from rust_python_jaeger_reporter import Reporter
+
+ @attr.s(slots=True, frozen=True)
+ class _WrappedRustReporter:
+ """Wrap the reporter to ensure `report_span` never throws.
+ """
+
+ _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
+
+ def set_process(self, *args, **kwargs):
+ return self._reporter.set_process(*args, **kwargs)
+
+ def report_span(self, span):
+ try:
+ return self._reporter.report_span(span)
+ except Exception:
+ logger.exception("Failed to report span")
+
+ RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]]
+except ImportError:
+ RustReporter = None
+
+
logger = logging.getLogger(__name__)
@@ -320,11 +345,19 @@ def init_tracer(hs: "HomeServer"):
set_homeserver_whitelist(hs.config.opentracer_whitelist)
- JaegerConfig(
+ config = JaegerConfig(
config=hs.config.jaeger_config,
service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
scope_manager=LogContextScopeManager(hs.config),
- ).initialize_tracer()
+ )
+
+ # If we have the rust jaeger reporter available let's use that.
+ if RustReporter:
+ logger.info("Using rust_python_jaeger_reporter library")
+ tracer = config.create_tracer(RustReporter(), config.sampler)
+ opentracing.set_global_tracer(tracer)
+ else:
+ config.initialize_tracer()
# Whitelisting
@@ -466,7 +499,9 @@ def start_active_span_from_edu(
if opentracing is None:
return _noop_context_manager()
- carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+ carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+ "opentracing", {}
+ )
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
_references = [
opentracing.child_of(span_context_from_string(x))
@@ -657,7 +692,7 @@ def active_span_context_as_string():
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
- return json.dumps(carrier)
+ return json_encoder.encode(carrier)
@only_if_tracing
@@ -666,7 +701,7 @@ def span_context_from_string(carrier):
Returns:
The active span context decoded from a string.
"""
- carrier = json.loads(carrier)
+ carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
@@ -700,37 +735,43 @@ def trace(func=None, opname=None):
_opname = opname if opname else func.__name__
- @wraps(func)
- def _trace_inner(*args, **kwargs):
- if opentracing is None:
- return func(*args, **kwargs)
+ if inspect.iscoroutinefunction(func):
- scope = start_active_span(_opname)
- scope.__enter__()
+ @wraps(func)
+ async def _trace_inner(*args, **kwargs):
+ with start_active_span(_opname):
+ return await func(*args, **kwargs)
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
+ else:
+ # The other case here handles both sync functions and those
+ # decorated with inlineDeferred.
+ @wraps(func)
+ def _trace_inner(*args, **kwargs):
+ scope = start_active_span(_opname)
+ scope.__enter__()
- def call_back(result):
- scope.__exit__(None, None, None)
- return result
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
- def err_back(result):
- scope.span.set_tag(tags.ERROR, True)
- scope.__exit__(None, None, None)
- return result
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
- result.addCallbacks(call_back, err_back)
+ def err_back(result):
+ scope.__exit__(None, None, None)
+ return result
- else:
- scope.__exit__(None, None, None)
+ result.addCallbacks(call_back, err_back)
- return result
+ else:
+ scope.__exit__(None, None, None)
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
return _trace_inner
@@ -760,48 +801,42 @@ def tag_args(func):
return _tag_args_inner
-def trace_servlet(servlet_name, extract_context=False):
- """Decorator which traces a serlet. It starts a span with some servlet specific
- tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+ """Returns a context manager which traces a request. It starts a span
+ with some servlet specific tags such as the request metrics name and
+ request information.
Args:
- servlet_name (str): The name to be used for the span's operation_name
- extract_context (bool): Whether to attempt to extract the opentracing
+ request
+ extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
-
"""
- def _trace_servlet_inner_1(func):
- if not opentracing:
- return func
-
- @wraps(func)
- async def _trace_servlet_inner(request, *args, **kwargs):
- request_tags = {
- "request_id": request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- }
-
- if extract_context:
- scope = start_active_span_from_request(
- request, servlet_name, tags=request_tags
- )
- else:
- scope = start_active_span(servlet_name, tags=request_tags)
-
- with scope:
- result = func(request, *args, **kwargs)
+ if opentracing is None:
+ yield
+ return
- if not isinstance(result, (types.CoroutineType, defer.Deferred)):
- # Some servlets aren't async and just return results
- # directly, so we handle that here.
- return result
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ }
- return await result
+ request_name = request.request_metrics.name
+ if extract_context:
+ scope = start_active_span_from_request(request, request_name, tags=request_tags)
+ else:
+ scope = start_active_span(request_name, tags=request_tags)
- return _trace_servlet_inner
+ with scope:
+ try:
+ yield
+ finally:
+ # We set the operation name again in case its changed (which happens
+ # with JsonResource).
+ scope.span.set_operation_name(request.request_metrics.name)
- return _trace_servlet_inner_1
+ scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|