diff options
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r-- | synapse/logging/opentracing.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 7246253018..0638cec429 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -169,7 +169,9 @@ import contextlib import inspect import logging import re +import types from functools import wraps +from typing import Dict from canonicaljson import json @@ -223,8 +225,8 @@ try: from jaeger_client import Config as JaegerConfig from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: - JaegerConfig = None - LogContextScopeManager = None + JaegerConfig = None # type: ignore + LogContextScopeManager = None # type: ignore logger = logging.getLogger(__name__) @@ -547,7 +549,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T return span = opentracing.tracer.active_span - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -584,7 +586,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True): span = opentracing.tracer.active_span - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -639,7 +641,7 @@ def get_active_span_text_map(destination=None): if destination and not whitelisted_homeserver(destination): return {} - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier ) @@ -653,7 +655,7 @@ def active_span_context_as_string(): Returns: The active span context encoded as a string. """ - carrier = {} + carrier = {} # type: Dict[str, str] if opentracing: opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier @@ -777,8 +779,7 @@ def trace_servlet(servlet_name, extract_context=False): return func @wraps(func) - @defer.inlineCallbacks - def _trace_servlet_inner(request, *args, **kwargs): + async def _trace_servlet_inner(request, *args, **kwargs): request_tags = { "request_id": request.get_request_id(), tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, @@ -795,8 +796,14 @@ def trace_servlet(servlet_name, extract_context=False): scope = start_active_span(servlet_name, tags=request_tags) with scope: - result = yield defer.maybeDeferred(func, request, *args, **kwargs) - return result + result = func(request, *args, **kwargs) + + if not isinstance(result, (types.CoroutineType, defer.Deferred)): + # Some servlets aren't async and just return results + # directly, so we handle that here. + return result + + return await result return _trace_servlet_inner |