summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py114
1 files changed, 80 insertions, 34 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 256b972aaa..7246253018 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -239,8 +239,7 @@ _homeserver_whitelist = None
 
 
 def only_if_tracing(func):
-    """Executes the function only if we're tracing. Otherwise return.
-    Assumes the function wrapped may return None"""
+    """Executes the function only if we're tracing. Otherwise returns None."""
 
     @wraps(func)
     def _only_if_tracing_inner(*args, **kwargs):
@@ -252,6 +251,41 @@ def only_if_tracing(func):
     return _only_if_tracing_inner
 
 
+def ensure_active_span(message, ret=None):
+    """Executes the operation only if opentracing is enabled and there is an active span.
+    If there is no active span it logs message at the error level.
+
+    Args:
+        message (str): Message which fills in "There was no active span when trying to %s"
+            in the error log if there is no active span and opentracing is enabled.
+        ret (object): return value if opentracing is None or there is no active span.
+
+    Returns (object): The result of the func or ret if opentracing is disabled or there
+        was no active span.
+    """
+
+    def ensure_active_span_inner_1(func):
+        @wraps(func)
+        def ensure_active_span_inner_2(*args, **kwargs):
+            if not opentracing:
+                return ret
+
+            if not opentracing.tracer.active_span:
+                logger.error(
+                    "There was no active span when trying to %s."
+                    " Did you forget to start one or did a context slip?",
+                    message,
+                )
+
+                return ret
+
+            return func(*args, **kwargs)
+
+        return ensure_active_span_inner_2
+
+    return ensure_active_span_inner_1
+
+
 @contextlib.contextmanager
 def _noop_context_manager(*args, **kwargs):
     """Does exactly what it says on the tin"""
@@ -319,7 +353,7 @@ def whitelisted_homeserver(destination):
     Args:
         destination (str)
         """
-    _homeserver_whitelist
+
     if _homeserver_whitelist:
         return _homeserver_whitelist.match(destination)
     return False
@@ -349,26 +383,24 @@ def start_active_span(
     if opentracing is None:
         return _noop_context_manager()
 
-    else:
-        # We need to enter the scope here for the logcontext to become active
-        return opentracing.tracer.start_active_span(
-            operation_name,
-            child_of=child_of,
-            references=references,
-            tags=tags,
-            start_time=start_time,
-            ignore_active_span=ignore_active_span,
-            finish_on_close=finish_on_close,
-        )
+    return opentracing.tracer.start_active_span(
+        operation_name,
+        child_of=child_of,
+        references=references,
+        tags=tags,
+        start_time=start_time,
+        ignore_active_span=ignore_active_span,
+        finish_on_close=finish_on_close,
+    )
 
 
 def start_active_span_follows_from(operation_name, contexts):
     if opentracing is None:
         return _noop_context_manager()
-    else:
-        references = [opentracing.follows_from(context) for context in contexts]
-        scope = start_active_span(operation_name, references=references)
-        return scope
+
+    references = [opentracing.follows_from(context) for context in contexts]
+    scope = start_active_span(operation_name, references=references)
+    return scope
 
 
 def start_active_span_from_request(
@@ -465,19 +497,19 @@ def start_active_span_from_edu(
 # Opentracing setters for tags, logs, etc
 
 
-@only_if_tracing
+@ensure_active_span("set a tag")
 def set_tag(key, value):
     """Sets a tag on the active span"""
     opentracing.tracer.active_span.set_tag(key, value)
 
 
-@only_if_tracing
+@ensure_active_span("log")
 def log_kv(key_values, timestamp=None):
     """Log to the active span"""
     opentracing.tracer.active_span.log_kv(key_values, timestamp)
 
 
-@only_if_tracing
+@ensure_active_span("set the traces operation name")
 def set_operation_name(operation_name):
     """Sets the operation name of the active span"""
     opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -486,13 +518,18 @@ def set_operation_name(operation_name):
 # Injection and extraction
 
 
-@only_if_tracing
+@ensure_active_span("inject the span into a header")
 def inject_active_span_twisted_headers(headers, destination, check_destination=True):
     """
     Injects a span context into twisted headers in-place
 
     Args:
         headers (twisted.web.http_headers.Headers)
+        destination (str): address of entity receiving the span context. If check_destination
+            is true the context will only be injected if the destination matches the
+            opentracing whitelist
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
         span (opentracing.Span)
 
     Returns:
@@ -517,7 +554,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
         headers.addRawHeaders(key, value)
 
 
-@only_if_tracing
+@ensure_active_span("inject the span into a byte dict")
 def inject_active_span_byte_dict(headers, destination, check_destination=True):
     """
     Injects a span context into a dict where the headers are encoded as byte
@@ -525,6 +562,11 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
 
     Args:
         headers (dict)
+        destination (str): address of entity receiving the span context. If check_destination
+            is true the context will only be injected if the destination matches the
+            opentracing whitelist
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
         span (opentracing.Span)
 
     Returns:
@@ -537,7 +579,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
         here:
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
-    if not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     span = opentracing.tracer.active_span
@@ -549,16 +591,18 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
         headers[key.encode()] = [value.encode()]
 
 
-@only_if_tracing
+@ensure_active_span("inject the span into a text map")
 def inject_active_span_text_map(carrier, destination, check_destination=True):
     """
     Injects a span context into a dict
 
     Args:
         carrier (dict)
-        destination (str): the name of the remote server. The span context
-        will only be injected if the destination matches the homeserver_whitelist
-        or destination is None.
+        destination (str): address of entity receiving the span context. If check_destination
+            is true the context will only be injected if the destination matches the
+            opentracing whitelist
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
 
     Returns:
         In-place modification of carrier
@@ -579,6 +623,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
     )
 
 
+@ensure_active_span("get the active span context as a dict", ret={})
 def get_active_span_text_map(destination=None):
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -591,7 +636,7 @@ def get_active_span_text_map(destination=None):
         dict: the active span's context if opentracing is enabled, otherwise empty.
     """
 
-    if not opentracing or (destination and not whitelisted_homeserver(destination)):
+    if destination and not whitelisted_homeserver(destination):
         return {}
 
     carrier = {}
@@ -602,6 +647,7 @@ def get_active_span_text_map(destination=None):
     return carrier
 
 
+@ensure_active_span("get the span context as a string.", ret={})
 def active_span_context_as_string():
     """
     Returns:
@@ -656,15 +702,15 @@ def trace(func=None, opname=None):
         _opname = opname if opname else func.__name__
 
         @wraps(func)
-        def _trace_inner(self, *args, **kwargs):
+        def _trace_inner(*args, **kwargs):
             if opentracing is None:
-                return func(self, *args, **kwargs)
+                return func(*args, **kwargs)
 
             scope = start_active_span(_opname)
             scope.__enter__()
 
             try:
-                result = func(self, *args, **kwargs)
+                result = func(*args, **kwargs)
                 if isinstance(result, defer.Deferred):
 
                     def call_back(result):
@@ -704,13 +750,13 @@ def tag_args(func):
         return func
 
     @wraps(func)
-    def _tag_args_inner(self, *args, **kwargs):
+    def _tag_args_inner(*args, **kwargs):
         argspec = inspect.getargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
         set_tag("kwargs", kwargs)
-        return func(self, *args, **kwargs)
+        return func(*args, **kwargs)
 
     return _tag_args_inner