diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index e52567afa0..3409ddf0d0 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -34,7 +34,6 @@ import threading
import typing
import warnings
from collections.abc import Coroutine, Generator
-from contextvars import ContextVar
from types import TracebackType
from typing import (
TYPE_CHECKING,
@@ -235,7 +234,14 @@ LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel:
"""Sentinel to represent the root context"""
- __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+ __slots__ = [
+ "previous_context",
+ "finished",
+ "request",
+ "scope",
+ "tag",
+ "metrics_name",
+ ]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
@@ -244,6 +250,7 @@ class _Sentinel:
self.request = None
self.scope = None
self.tag = None
+ self.metrics_name = None
def __str__(self) -> str:
return "sentinel"
@@ -296,6 +303,7 @@ class LoggingContext:
"request",
"tag",
"scope",
+ "metrics_name",
]
def __init__(
@@ -306,6 +314,8 @@ class LoggingContext:
) -> None:
self.previous_context = current_context()
+ self.metrics_name: Optional[str] = None
+
# track the resources used by this context so far
self._resource_usage = ContextResourceUsage()
@@ -339,6 +349,7 @@ class LoggingContext:
# if we don't have a `name`, but do have a parent context, use its name.
if self.parent_context and name is None:
name = str(self.parent_context)
+ self.metrics_name = self.parent_context.metrics_name
if name is None:
raise ValueError(
"LoggingContext must be given either a name or a parent context"
@@ -821,14 +832,14 @@ def run_in_background(
d: "defer.Deferred[R]"
if isinstance(res, typing.Coroutine):
# Wrap the coroutine in a `Deferred`.
- d = defer.ensureDeferred(measure_coroutine(current.name, res))
+ d = defer.ensureDeferred(measure_coroutine(current.metrics_name, res))
elif isinstance(res, defer.Deferred):
d = res
elif isinstance(res, Awaitable):
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# or `Future` from `make_awaitable`.
d = defer.ensureDeferred(
- measure_coroutine(current.name, _unwrap_awaitable(res))
+ measure_coroutine(current.metrics_name, _unwrap_awaitable(res))
)
else:
# `res` is a plain value. Wrap it in a `Deferred`.
@@ -1069,6 +1080,10 @@ class _ResourceTracker2(Coroutine[defer.Deferred[Any], Any, _T]):
async def measure_coroutine(
- name: str, co: Coroutine[defer.Deferred[Any], Any, _T]
+ name: Optional[str], co: Coroutine[defer.Deferred[Any], Any, _T]
) -> _T:
+ if not name:
+ return await co
+
+ current_context().metrics_name = name
return await _ResourceTracker2(name, co)
|