summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py44
1 files changed, 24 insertions, 20 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 308a27213b..5dddf57008 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
 import inspect
 import logging
 import re
+import types
 from functools import wraps
+from typing import TYPE_CHECKING, Dict
 
 from canonicaljson import json
 
@@ -177,6 +179,9 @@ from twisted.internet import defer
 
 from synapse.config import ConfigError
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Helper class
 
 
@@ -295,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
 # Setup
 
 
-def init_tracer(config):
+def init_tracer(hs: "HomeServer"):
     """Set the whitelists and initialise the JaegerClient tracer
-
-    Args:
-        config (HomeserverConfig): The config used by the homeserver
     """
     global opentracing
-    if not config.opentracer_enabled:
+    if not hs.config.opentracer_enabled:
         # We don't have a tracer
         opentracing = None
         return
@@ -313,18 +315,15 @@ def init_tracer(config):
             "installed."
         )
 
-    # Include the worker name
-    name = config.worker_name if config.worker_name else "master"
-
     # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
     # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
 
-    set_homeserver_whitelist(config.opentracer_whitelist)
+    set_homeserver_whitelist(hs.config.opentracer_whitelist)
 
     JaegerConfig(
-        config=config.jaeger_config,
-        service_name="{} {}".format(config.server_name, name),
-        scope_manager=LogContextScopeManager(config),
+        config=hs.config.jaeger_config,
+        service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+        scope_manager=LogContextScopeManager(hs.config),
     ).initialize_tracer()
 
 
@@ -547,7 +546,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
         return
 
     span = opentracing.tracer.active_span
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -584,7 +583,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
 
     span = opentracing.tracer.active_span
 
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -639,7 +638,7 @@ def get_active_span_text_map(destination=None):
     if destination and not whitelisted_homeserver(destination):
         return {}
 
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(
         opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
     )
@@ -653,7 +652,7 @@ def active_span_context_as_string():
     Returns:
         The active span context encoded as a string.
     """
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     if opentracing:
         opentracing.tracer.inject(
             opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +776,7 @@ def trace_servlet(servlet_name, extract_context=False):
             return func
 
         @wraps(func)
-        @defer.inlineCallbacks
-        def _trace_servlet_inner(request, *args, **kwargs):
+        async def _trace_servlet_inner(request, *args, **kwargs):
             request_tags = {
                 "request_id": request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +793,14 @@ def trace_servlet(servlet_name, extract_context=False):
                 scope = start_active_span(servlet_name, tags=request_tags)
 
             with scope:
-                result = yield defer.maybeDeferred(func, request, *args, **kwargs)
-                return result
+                result = func(request, *args, **kwargs)
+
+                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+                    # Some servlets aren't async and just return results
+                    # directly, so we handle that here.
+                    return result
+
+                return await result
 
         return _trace_servlet_inner