diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 308a27213b..5dddf57008 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
import inspect
import logging
import re
+import types
from functools import wraps
+from typing import TYPE_CHECKING, Dict
from canonicaljson import json
@@ -177,6 +179,9 @@ from twisted.internet import defer
from synapse.config import ConfigError
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# Helper class
@@ -295,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
# Setup
-def init_tracer(config):
+def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer
-
- Args:
- config (HomeserverConfig): The config used by the homeserver
"""
global opentracing
- if not config.opentracer_enabled:
+ if not hs.config.opentracer_enabled:
# We don't have a tracer
opentracing = None
return
@@ -313,18 +315,15 @@ def init_tracer(config):
"installed."
)
- # Include the worker name
- name = config.worker_name if config.worker_name else "master"
-
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
- set_homeserver_whitelist(config.opentracer_whitelist)
+ set_homeserver_whitelist(hs.config.opentracer_whitelist)
JaegerConfig(
- config=config.jaeger_config,
- service_name="{} {}".format(config.server_name, name),
- scope_manager=LogContextScopeManager(config),
+ config=hs.config.jaeger_config,
+ service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+ scope_manager=LogContextScopeManager(hs.config),
).initialize_tracer()
@@ -547,7 +546,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -584,7 +583,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -639,7 +638,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@@ -653,7 +652,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
- carrier = {}
+ carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +776,7 @@ def trace_servlet(servlet_name, extract_context=False):
return func
@wraps(func)
- @defer.inlineCallbacks
- def _trace_servlet_inner(request, *args, **kwargs):
+ async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +793,14 @@ def trace_servlet(servlet_name, extract_context=False):
scope = start_active_span(servlet_name, tags=request_tags)
with scope:
- result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- return result
+ result = func(request, *args, **kwargs)
+
+ if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+ # Some servlets aren't async and just return results
+ # directly, so we handle that here.
+ return result
+
+ return await result
return _trace_servlet_inner
|