summary refs log tree commit diff
path: root/synapse/logging/opentracing.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging/opentracing.py')
-rw-r--r--synapse/logging/opentracing.py68
1 files changed, 31 insertions, 37 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 73bef5e5ca..1676771ef0 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,6 @@ import contextlib
 import inspect
 import logging
 import re
-import types
 from functools import wraps
 from typing import TYPE_CHECKING, Dict, Optional, Type
 
@@ -182,6 +181,7 @@ from synapse.config import ConfigError
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.http.site import SynapseRequest
 
 # Helper class
 
@@ -793,48 +793,42 @@ def tag_args(func):
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, extract_context=False):
-    """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+    """Returns a context manager which traces a request. It starts a span
+    with some servlet specific tags such as the request metrics name and
+    request information.
 
     Args:
-        servlet_name (str): The name to be used for the span's operation_name
-        extract_context (bool): Whether to attempt to extract the opentracing
+        request
+        extract_context: Whether to attempt to extract the opentracing
             context from the request the servlet is handling.
-
     """
 
-    def _trace_servlet_inner_1(func):
-        if not opentracing:
-            return func
-
-        @wraps(func)
-        async def _trace_servlet_inner(request, *args, **kwargs):
-            request_tags = {
-                "request_id": request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-            }
-
-            if extract_context:
-                scope = start_active_span_from_request(
-                    request, servlet_name, tags=request_tags
-                )
-            else:
-                scope = start_active_span(servlet_name, tags=request_tags)
-
-            with scope:
-                result = func(request, *args, **kwargs)
+    if opentracing is None:
+        yield
+        return
 
-                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
-                    # Some servlets aren't async and just return results
-                    # directly, so we handle that here.
-                    return result
+    request_tags = {
+        "request_id": request.get_request_id(),
+        tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+        tags.HTTP_METHOD: request.get_method(),
+        tags.HTTP_URL: request.get_redacted_uri(),
+        tags.PEER_HOST_IPV6: request.getClientIP(),
+    }
 
-                return await result
+    request_name = request.request_metrics.name
+    if extract_context:
+        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+    else:
+        scope = start_active_span(request_name, tags=request_tags)
 
-        return _trace_servlet_inner
+    with scope:
+        try:
+            yield
+        finally:
+            # We set the operation name again in case its changed (which happens
+            # with JsonResource).
+            scope.span.set_operation_name(request.request_metrics.name)
 
-    return _trace_servlet_inner_1
+            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)