diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 73bef5e5ca..1676771ef0 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,6 @@ import contextlib
import inspect
import logging
import re
-import types
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
@@ -182,6 +181,7 @@ from synapse.config import ConfigError
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.http.site import SynapseRequest
# Helper class
@@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner
-def trace_servlet(servlet_name, extract_context=False):
- """Decorator which traces a serlet. It starts a span with some servlet specific
- tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+ """Returns a context manager which traces a request. It starts a span
+ with some servlet specific tags such as the request metrics name and
+ request information.
Args:
- servlet_name (str): The name to be used for the span's operation_name
- extract_context (bool): Whether to attempt to extract the opentracing
+ request
+ extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
-
"""
- def _trace_servlet_inner_1(func):
- if not opentracing:
- return func
-
- @wraps(func)
- async def _trace_servlet_inner(request, *args, **kwargs):
- request_tags = {
- "request_id": request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- }
-
- if extract_context:
- scope = start_active_span_from_request(
- request, servlet_name, tags=request_tags
- )
- else:
- scope = start_active_span(servlet_name, tags=request_tags)
-
- with scope:
- result = func(request, *args, **kwargs)
+ if opentracing is None:
+ yield
+ return
- if not isinstance(result, (types.CoroutineType, defer.Deferred)):
- # Some servlets aren't async and just return results
- # directly, so we handle that here.
- return result
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ }
- return await result
+ request_name = request.request_metrics.name
+ if extract_context:
+ scope = start_active_span_from_request(request, request_name, tags=request_tags)
+ else:
+ scope = start_active_span(request_name, tags=request_tags)
- return _trace_servlet_inner
+ with scope:
+ try:
+ yield
+ finally:
+ # We set the operation name again in case its changed (which happens
+ # with JsonResource).
+ scope.span.set_operation_name(request.request_metrics.name)
- return _trace_servlet_inner_1
+ scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|