summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/logging/opentracing.py57
-rw-r--r--synapse/storage/persist_events.py46
3 files changed, 98 insertions, 9 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 26a3b38918..cf4333a923 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -207,7 +207,7 @@ class Auth:
 
                 request.requester = user_id
                 if user_id in self._force_tracing_for_users:
-                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+                    opentracing.force_tracing()
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("user_id", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
@@ -260,7 +260,7 @@ class Auth:
 
             request.requester = requester
             if user_info.token_owner in self._force_tracing_for_users:
-                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+                opentracing.force_tracing()
             opentracing.set_tag("authenticated_entity", user_info.token_owner)
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 5b4725e035..4f18792c99 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -168,7 +168,7 @@ import inspect
 import logging
 import re
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, List, Optional, Pattern, Type
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
 
 import attr
 
@@ -278,6 +278,10 @@ class SynapseTags:
     DB_TXN_ID = "db.txn_id"
 
 
+class SynapseBaggage:
+    FORCE_TRACING = "synapse-force-tracing"
+
+
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
@@ -285,6 +289,8 @@ _homeserver_whitelist = None  # type: Optional[Pattern[str]]
 
 # Util methods
 
+Sentinel = object()
+
 
 def only_if_tracing(func):
     """Executes the function only if we're tracing. Otherwise returns None."""
@@ -447,12 +453,28 @@ def start_active_span(
     )
 
 
-def start_active_span_follows_from(operation_name, contexts):
+def start_active_span_follows_from(
+    operation_name: str, contexts: Collection, inherit_force_tracing=False
+):
+    """Starts an active opentracing span, with additional references to previous spans
+
+    Args:
+        operation_name: name of the operation represented by the new span
+        contexts: the previous spans to inherit from
+        inherit_force_tracing: if set, and any of the previous contexts have had tracing
+           forced, the new span will also have tracing forced.
+    """
     if opentracing is None:
         return noop_context_manager()
 
     references = [opentracing.follows_from(context) for context in contexts]
     scope = start_active_span(operation_name, references=references)
+
+    if inherit_force_tracing and any(
+        is_context_forced_tracing(ctx) for ctx in contexts
+    ):
+        force_tracing(scope.span)
+
     return scope
 
 
@@ -551,6 +573,10 @@ def start_active_span_from_edu(
 
 
 # Opentracing setters for tags, logs, etc
+@only_if_tracing
+def active_span():
+    """Get the currently active span, if any"""
+    return opentracing.tracer.active_span
 
 
 @ensure_active_span("set a tag")
@@ -571,6 +597,33 @@ def set_operation_name(operation_name):
     opentracing.tracer.active_span.set_operation_name(operation_name)
 
 
+@only_if_tracing
+def force_tracing(span=Sentinel) -> None:
+    """Force sampling for the active/given span and its children.
+
+    Args:
+        span: span to force tracing for. By default, the active span.
+    """
+    if span is Sentinel:
+        span = opentracing.tracer.active_span
+    if span is None:
+        logger.error("No active span in force_tracing")
+        return
+
+    span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+
+    # also set a bit of baggage, so that we have a way of figuring out if
+    # it is enabled later
+    span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
+
+
+def is_context_forced_tracing(span_context) -> bool:
+    """Check if sampling has been force for the given span context."""
+    if span_context is None:
+        return False
+    return span_context.baggage.get(SynapseBaggage.FORCE_TRACING) is not None
+
+
 # Injection and extraction
 
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index c11f6c5845..dc38942bb1 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,6 +18,7 @@ import itertools
 import logging
 from collections import deque
 from typing import (
+    Any,
     Awaitable,
     Callable,
     Collection,
@@ -40,6 +41,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
@@ -103,12 +105,18 @@ times_pruned_extremities = Counter(
 )
 
 
-@attr.s(auto_attribs=True, frozen=True, slots=True)
+@attr.s(auto_attribs=True, slots=True)
 class _EventPersistQueueItem:
     events_and_contexts: List[Tuple[EventBase, EventContext]]
     backfilled: bool
     deferred: ObservableDeferred
 
+    parent_opentracing_span_contexts: List = []
+    """A list of opentracing spans waiting for this batch"""
+
+    opentracing_span_context: Any = None
+    """The opentracing span under which the persistence actually happened"""
+
 
 _PersistResult = TypeVar("_PersistResult")
 
@@ -171,9 +179,27 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             )
             queue.append(end_item)
 
+        # add our events to the queue item
         end_item.events_and_contexts.extend(events_and_contexts)
+
+        # also add our active opentracing span to the item so that we get a link back
+        span = opentracing.active_span()
+        if span:
+            end_item.parent_opentracing_span_contexts.append(span.context)
+
+        # start a processor for the queue, if there isn't one already
         self._handle_queue(room_id)
-        return await make_deferred_yieldable(end_item.deferred.observe())
+
+        # wait for the queue item to complete
+        res = await make_deferred_yieldable(end_item.deferred.observe())
+
+        # add another opentracing span which links to the persist trace.
+        with opentracing.start_active_span_follows_from(
+            "persist_event_batch_complete", (end_item.opentracing_span_context,)
+        ):
+            pass
+
+        return res
 
     def _handle_queue(self, room_id):
         """Attempts to handle the queue for a room if not already being handled.
@@ -200,9 +226,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        ret = await self._per_item_callback(
-                            item.events_and_contexts, item.backfilled
-                        )
+                        with opentracing.start_active_span_follows_from(
+                            "persist_event_batch",
+                            item.parent_opentracing_span_contexts,
+                            inherit_force_tracing=True,
+                        ) as scope:
+                            if scope:
+                                item.opentracing_span_context = scope.span.context
+
+                            ret = await self._per_item_callback(
+                                item.events_and_contexts, item.backfilled
+                            )
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -252,6 +286,7 @@ class EventsPersistenceStorage:
         self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
+    @opentracing.trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -307,6 +342,7 @@ class EventsPersistenceStorage:
             self.main_store.get_room_max_token(),
         )
 
+    @opentracing.trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: