diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5d93ab07f1..6364290615 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ
import attr
from twisted.internet import defer
+from twisted.web.http import Request
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
@@ -490,48 +491,6 @@ def start_active_span_follows_from(
return scope
-def start_active_span_from_request(
- request,
- operation_name,
- references=None,
- tags=None,
- start_time=None,
- ignore_active_span=False,
- finish_on_close=True,
-):
- """
- Extracts a span context from a Twisted Request.
- args:
- headers (twisted.web.http.Request)
-
- For the other args see opentracing.tracer
-
- returns:
- span_context (opentracing.span.SpanContext)
- """
- # Twisted encodes the values as lists whereas opentracing doesn't.
- # So, we take the first item in the list.
- # Also, twisted uses byte arrays while opentracing expects strings.
-
- if opentracing is None:
- return noop_context_manager() # type: ignore[unreachable]
-
- header_dict = {
- k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
- }
- context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
-
- return opentracing.tracer.start_active_span(
- operation_name,
- child_of=context,
- references=references,
- tags=tags,
- start_time=start_time,
- ignore_active_span=ignore_active_span,
- finish_on_close=finish_on_close,
- )
-
-
def start_active_span_from_edu(
edu_content,
operation_name,
@@ -743,6 +702,20 @@ def active_span_context_as_string():
return json_encoder.encode(carrier)
+def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+ """Extract an opentracing context from the headers on an HTTP request
+
+ This is useful when we have received an HTTP request from another part of our
+ system, and want to link our spans to those of the remote system.
+ """
+ if not opentracing:
+ return None
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+ return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+
@only_if_tracing
def span_context_from_string(carrier):
"""
@@ -882,10 +855,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
}
request_name = request.request_metrics.name
- if extract_context:
- scope = start_active_span_from_request(request, request_name)
- else:
- scope = start_active_span(request_name)
+ context = span_context_from_request(request) if extract_context else None
+
+ # we configure the scope not to finish the span immediately on exit, and instead
+ # pass the span into the SynapseRequest, which will finish it once we've finished
+ # sending the response to the client.
+ scope = start_active_span(request_name, child_of=context, finish_on_close=False)
+ request.set_opentracing_span(scope.span)
with scope:
inject_response_headers(request.responseHeaders)
|