summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-03-24 14:45:33 +0000
committerGitHub <noreply@github.com>2020-03-24 14:45:33 +0000
commit39230d217104f3cd7aba9065dc478f935ce1e614 (patch)
tree3585baa74ba9dfeba89efc367a5ec79f02b847d3 /synapse
parentFix CAS redirect url (#6634) (diff)
downloadsynapse-39230d217104f3cd7aba9065dc478f935ce1e614.tar.xz
Clean up some LoggingContext stuff (#7120)
* Pull Sentinel out of LoggingContext

... and drop a few unnecessary references to it

* Factor out LoggingContext.current_context

move `current_context` and `set_context` out to top-level functions.

Mostly this means that I can more easily trace what's actually referring to
LoggingContext, but I think it's generally neater.

* move copy-to-parent into `stop`

this really just makes `start` and `stop` more symetric. It also means that it
behaves correctly if you manually `set_log_context` rather than using the
context manager.

* Replace `LoggingContext.alive` with `finished`

Turn `alive` into `finished` and make it a bit better defined.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/federation/federation_base.py4
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/http/request_metrics.py6
-rw-r--r--synapse/logging/_structured.py4
-rw-r--r--synapse/logging/context.py234
-rw-r--r--synapse/logging/scopecontextmanager.py13
-rw-r--r--synapse/storage/data_stores/main/events_worker.py4
-rw-r--r--synapse/storage/database.py11
-rw-r--r--synapse/util/metrics.py4
-rw-r--r--synapse/util/patch_inline_callbacks.py36
11 files changed, 161 insertions, 163 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 983f0ead8c..a9f4025bfe 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -43,8 +43,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
     preserve_fn,
     run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
         """
 
         try:
-            ctx = LoggingContext.current_context()
+            ctx = current_context()
 
             # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index b0b0eba41e..4b115aac04 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.types import JsonDict, get_domain_from_id
@@ -78,7 +78,7 @@ class FederationBase(object):
         """
         deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
 
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
         def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 669dbc8a48..5746fdea14 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -26,7 +26,7 @@ from prometheus_client import Counter
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.events import EventBase
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
         else:
             sync_type = "incremental_sync"
 
-        context = LoggingContext.current_context()
+        context = current_context()
         if context:
             context.tag = sync_type
 
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 58f9cc61c8..b58ae3d9db 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -19,7 +19,7 @@ import threading
 
 from prometheus_client.core import Counter, Histogram
 
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 from synapse.metrics import LaterGauge
 
 logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
 class RequestMetrics(object):
     def start(self, time_sec, name, method):
         self.start = time_sec
-        self.start_context = LoggingContext.current_context()
+        self.start_context = current_context()
         self.name = name
         self.method = method
 
@@ -163,7 +163,7 @@ class RequestMetrics(object):
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)
 
-        context = LoggingContext.current_context()
+        context = current_context()
 
         tag = ""
         if context:
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index ffa7b20ca8..7372450b45 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
     TerseJSONToConsoleLogObserver,
     TerseJSONToTCPLogObserver,
 )
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 
 
 def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
             ].startswith("Timing out client"):
                 return
 
-        context = LoggingContext.current_context()
+        context = current_context()
 
         # Copy the context information to the log event.
         if context is not None:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 860b99a4c6..a8eafb1c7c 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
         return res
 
 
-LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+    """Sentinel to represent the root context"""
+
+    __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+    def __init__(self) -> None:
+        # Minimal set for compatibility with LoggingContext
+        self.previous_context = None
+        self.finished = False
+        self.request = None
+        self.scope = None
+        self.tag = None
+
+    def __str__(self):
+        return "sentinel"
+
+    def copy_to(self, record):
+        pass
+
+    def copy_to_twisted_log_entry(self, record):
+        record["request"] = None
+        record["scope"] = None
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+    def add_database_transaction(self, duration_sec):
+        pass
+
+    def add_database_scheduled(self, sched_sec):
+        pass
+
+    def record_event_fetch(self, event_count):
+        pass
+
+    def __nonzero__(self):
+        return False
+
+    __bool__ = __nonzero__  # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
 
 
 class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
         "_resource_usage",
         "usage_start",
         "main_thread",
-        "alive",
+        "finished",
         "request",
         "tag",
         "scope",
     ]
 
-    thread_local = threading.local()
-
-    class Sentinel(object):
-        """Sentinel to represent the root context"""
-
-        __slots__ = ["previous_context", "alive", "request", "scope", "tag"]
-
-        def __init__(self) -> None:
-            # Minimal set for compatibility with LoggingContext
-            self.previous_context = None
-            self.alive = None
-            self.request = None
-            self.scope = None
-            self.tag = None
-
-        def __str__(self):
-            return "sentinel"
-
-        def copy_to(self, record):
-            pass
-
-        def copy_to_twisted_log_entry(self, record):
-            record["request"] = None
-            record["scope"] = None
-
-        def start(self):
-            pass
-
-        def stop(self):
-            pass
-
-        def add_database_transaction(self, duration_sec):
-            pass
-
-        def add_database_scheduled(self, sched_sec):
-            pass
-
-        def record_event_fetch(self, event_count):
-            pass
-
-        def __nonzero__(self):
-            return False
-
-        __bool__ = __nonzero__  # python3
-
-    sentinel = Sentinel()
-
     def __init__(self, name=None, parent_context=None, request=None) -> None:
-        self.previous_context = LoggingContext.current_context()
+        self.previous_context = current_context()
         self.name = name
 
         # track the resources used by this context so far
         self._resource_usage = ContextResourceUsage()
 
-        # If alive has the thread resource usage when the logcontext last
-        # became active.
+        # The thread resource usage when the logcontext became active. None
+        # if the context is not currently active.
         self.usage_start = None
 
         self.main_thread = get_thread_id()
         self.request = None
         self.tag = ""
-        self.alive = True
         self.scope = None  # type: Optional[_LogContextScope]
 
+        # keep track of whether we have hit the __exit__ block for this context
+        # (suggesting that the the thing that created the context thinks it should
+        # be finished, and that re-activating it would suggest an error).
+        self.finished = False
+
         self.parent_context = parent_context
 
         if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
             return str(self.request)
         return "%s@%x" % (self.name, id(self))
 
-    @classmethod
-    def current_context(cls) -> LoggingContextOrSentinel:
-        """Get the current logging context from thread local storage
-
-        Returns:
-            LoggingContext: the current logging context
-        """
-        return getattr(cls.thread_local, "current_context", cls.sentinel)
-
-    @classmethod
-    def set_current_context(
-        cls, context: LoggingContextOrSentinel
-    ) -> LoggingContextOrSentinel:
-        """Set the current logging context in thread local storage
-        Args:
-            context(LoggingContext): The context to activate.
-        Returns:
-            The context that was previously active
-        """
-        current = cls.current_context()
-
-        if current is not context:
-            current.stop()
-            cls.thread_local.current_context = context
-            context.start()
-        return current
-
     def __enter__(self) -> "LoggingContext":
         """Enters this logging context into thread local storage"""
-        old_context = self.set_current_context(self)
+        old_context = set_current_context(self)
         if self.previous_context != old_context:
             logger.warning(
                 "Expected previous context %r, found %r",
                 self.previous_context,
                 old_context,
             )
-        self.alive = True
-
         return self
 
     def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
         Returns:
             None to avoid suppressing any exceptions that were thrown.
         """
-        current = self.set_current_context(self.previous_context)
+        current = set_current_context(self.previous_context)
         if current is not self:
-            if current is self.sentinel:
+            if current is SENTINEL_CONTEXT:
                 logger.warning("Expected logging context %s was lost", self)
             else:
                 logger.warning(
                     "Expected logging context %s but found %s", self, current
                 )
-        self.alive = False
-
-        # if we have a parent, pass our CPU usage stats on
-        if self.parent_context is not None and hasattr(
-            self.parent_context, "_resource_usage"
-        ):
-            self.parent_context._resource_usage += self._resource_usage
 
-            # reset them in case we get entered again
-            self._resource_usage.reset()
+        # the fact that we are here suggests that the caller thinks that everything
+        # is done and dusted for this logcontext, and further activity will not get
+        # recorded against the correct metrics.
+        self.finished = True
 
     def copy_to(self, record) -> None:
         """Copy logging fields from this context to a log record or
@@ -371,9 +341,14 @@ class LoggingContext(object):
             logger.warning("Started logcontext %s on different thread", self)
             return
 
+        if self.finished:
+            logger.warning("Re-starting finished log context %s", self)
+
         # If we haven't already started record the thread resource usage so
         # far
-        if not self.usage_start:
+        if self.usage_start:
+            logger.warning("Re-starting already-active log context %s", self)
+        else:
             self.usage_start = get_thread_resource_usage()
 
     def stop(self) -> None:
@@ -396,6 +371,15 @@ class LoggingContext(object):
 
         self.usage_start = None
 
+        # if we have a parent, pass our CPU usage stats on
+        if self.parent_context is not None and hasattr(
+            self.parent_context, "_resource_usage"
+        ):
+            self.parent_context._resource_usage += self._resource_usage
+
+            # reset them in case we get entered again
+            self._resource_usage.reset()
+
     def get_resource_usage(self) -> ContextResourceUsage:
         """Get resources used by this logcontext so far.
 
@@ -409,7 +393,7 @@ class LoggingContext(object):
         # If we are on the correct thread and we're currently running then we
         # can include resource usage so far.
         is_main_thread = get_thread_id() == self.main_thread
-        if self.alive and self.usage_start and is_main_thread:
+        if self.usage_start and is_main_thread:
             utime_delta, stime_delta = self._get_cputime()
             res.ru_utime += utime_delta
             res.ru_stime += stime_delta
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
         Returns:
             True to include the record in the log output.
         """
-        context = LoggingContext.current_context()
+        context = current_context()
         for key, value in self.defaults.items():
             setattr(record, key, value)
 
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
 
     __slots__ = ["current_context", "new_context", "has_parent"]
 
-    def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
-        if new_context is None:
-            self.new_context = LoggingContext.sentinel  # type: LoggingContextOrSentinel
-        else:
-            self.new_context = new_context
+    def __init__(
+        self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+    ) -> None:
+        self.new_context = new_context
 
     def __enter__(self) -> None:
         """Captures the current logging context"""
-        self.current_context = LoggingContext.set_current_context(self.new_context)
+        self.current_context = set_current_context(self.new_context)
 
         if self.current_context:
             self.has_parent = self.current_context.previous_context is not None
-            if not self.current_context.alive:
-                logger.debug("Entering dead context: %s", self.current_context)
 
     def __exit__(self, type, value, traceback) -> None:
         """Restores the current logging context"""
-        context = LoggingContext.set_current_context(self.current_context)
+        context = set_current_context(self.current_context)
 
         if context != self.new_context:
-            if context is LoggingContext.sentinel:
+            if not context:
                 logger.warning("Expected logging context %s was lost", self.new_context)
             else:
                 logger.warning(
@@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
                     context,
                 )
 
-        if self.current_context is not LoggingContext.sentinel:
-            if not self.current_context.alive:
-                logger.debug("Restoring dead context: %s", self.current_context)
+
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+    """Get the current logging context from thread local storage"""
+    return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
+
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+    """Set the current logging context in thread local storage
+    Args:
+        context(LoggingContext): The context to activate.
+    Returns:
+        The context that was previously active
+    """
+    current = current_context()
+
+    if current is not context:
+        current.stop()
+        _thread_local.current_context = context
+        context.start()
+    return current
 
 
 def nested_logging_context(
@@ -572,7 +574,7 @@ def nested_logging_context(
     if parent_context is not None:
         context = parent_context  # type: LoggingContextOrSentinel
     else:
-        context = LoggingContext.current_context()
+        context = current_context()
     return LoggingContext(
         parent_context=context, request=str(context.request) + "-" + suffix
     )
@@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
     CRITICAL error about an unhandled error will be logged without much
     indication about where it came from.
     """
-    current = LoggingContext.current_context()
+    current = current_context()
     try:
         res = f(*args, **kwargs)
     except:  # noqa: E722
@@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
 
     # The function may have reset the context before returning, so
     # we need to restore it now.
-    ctx = LoggingContext.set_current_context(current)
+    ctx = set_current_context(current)
 
     # The original context will be restored when the deferred
     # completes, but there is nothing waiting for it, so it will
@@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
 
     # ok, we can't be sure that a yield won't block, so let's reset the
     # logcontext, and add a callback to the deferred to restore it.
-    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    prev_context = set_current_context(SENTINEL_CONTEXT)
     deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
@@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
 
 def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     """A callback function which just sets the logging context"""
-    LoggingContext.set_current_context(context)
+    set_current_context(context)
     return result
 
 
@@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = LoggingContext.current_context()
+    logcontext = current_context()
 
     def g():
         with LoggingContext(parent_context=logcontext):
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 4eed4f2338..dc3ab00cbb 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
 
 import twisted
 
-from synapse.logging.context import LoggingContext, nested_logging_context
+from synapse.logging.context import current_context, nested_logging_context
 
 logger = logging.getLogger(__name__)
 
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
             (Scope) : the Scope that is active, or None if not
             available.
         """
-        ctx = LoggingContext.current_context()
-        if ctx is LoggingContext.sentinel:
-            return None
-        else:
-            return ctx.scope
+        ctx = current_context()
+        return ctx.scope
 
     def activate(self, span, finish_on_close):
         """
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
         """
 
         enter_logcontext = False
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
-        if ctx is LoggingContext.sentinel:
+        if not ctx:
             # We don't want this scope to affect.
             logger.error("Tried to activate scope outside of loggingcontext")
             return Scope(None, span)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..3013f49d32 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids = [e for e in event_ids if e not in event_entry_map]
 
         if missing_events_ids:
-            log_ctx = LoggingContext.current_context()
+            log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
             # Note that _get_events_from_db is also responsible for turning db rows
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..715c0346dd 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
     LoggingContextOrSentinel,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +484,7 @@ class Database(object):
             end = monotonic_time()
             duration = end - start
 
-            LoggingContext.current_context().add_database_transaction(duration)
+            current_context().add_database_transaction(duration)
 
             transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
 
@@ -510,7 +511,7 @@ class Database(object):
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
 
-        if LoggingContext.current_context() == LoggingContext.sentinel:
+        if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
@@ -547,10 +548,8 @@ class Database(object):
         Returns:
             Deferred: The result of func
         """
-        parent_context = (
-            LoggingContext.current_context()
-        )  # type: Optional[LoggingContextOrSentinel]
-        if parent_context == LoggingContext.sentinel:
+        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
+        if not parent_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 7b18455469..ec61e14423 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -21,7 +21,7 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
 from synapse.metrics import InFlightGauge
 
 logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class Measure(object):
             raise RuntimeError("Measure() objects cannot be re-used")
 
         self.start = self.clock.time()
-        parent_context = LoggingContext.current_context()
+        parent_context = current_context()
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 3925927f9f..fdff195771 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -32,7 +32,7 @@ def do_patch():
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
     """
 
-    from synapse.logging.context import LoggingContext
+    from synapse.logging.context import current_context
 
     global _already_patched
 
@@ -43,35 +43,35 @@ def do_patch():
     def new_inline_callbacks(f):
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
-            start_context = LoggingContext.current_context()
+            start_context = current_context()
             changes = []  # type: List[str]
             orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
             try:
                 res = orig(*args, **kwargs)
             except Exception:
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
 
                     err = "%s changed context from %s to %s on exception" % (
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     print(err, file=sys.stderr)
                     raise Exception(err)
                 raise
 
             if not isinstance(res, Deferred) or res.called:
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
 
                     err = "Completed %s changed context from %s to %s" % (
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     # print the error to stderr because otherwise all we
                     # see in travis-ci is the 500 error
@@ -79,23 +79,23 @@ def do_patch():
                     raise Exception(err)
                 return res
 
-            if LoggingContext.current_context() != LoggingContext.sentinel:
+            if current_context():
                 err = (
                     "%s returned incomplete deferred in non-sentinel context "
                     "%s (start was %s)"
-                ) % (f, LoggingContext.current_context(), start_context)
+                ) % (f, current_context(), start_context)
                 print(err, file=sys.stderr)
                 raise Exception(err)
 
             def check_ctx(r):
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
                     err = "%s completion of %s changed context from %s to %s" % (
                         "Failure" if isinstance(r, Failure) else "Success",
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     print(err, file=sys.stderr)
                     raise Exception(err)
@@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         function
     """
 
-    from synapse.logging.context import LoggingContext
+    from synapse.logging.context import current_context
 
     @functools.wraps(f)
     def check_yield_points_inner(*args, **kwargs):
@@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         last_yield_line_no = gen.gi_frame.f_lineno
         result = None  # type: Any
         while True:
-            expected_context = LoggingContext.current_context()
+            expected_context = current_context()
 
             try:
                 isFailure = isinstance(result, Failure)
@@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 else:
                     d = gen.send(result)
             except (StopIteration, defer._DefGen_Return) as e:
-                if LoggingContext.current_context() != expected_context:
+                if current_context() != expected_context:
                     # This happens when the context is lost sometime *after* the
                     # final yield and returning. E.g. we forgot to yield on a
                     # function that returns a deferred.
@@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                         % (
                             f.__qualname__,
                             expected_context,
-                            LoggingContext.current_context(),
+                            current_context(),
                             f.__code__.co_filename,
                             last_yield_line_no,
                         )
@@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 # This happens if we yield on a deferred that doesn't follow
                 # the log context rules without wrapping in a `make_deferred_yieldable`.
                 # We raise here as this should never happen.
-                if LoggingContext.current_context() is not LoggingContext.sentinel:
+                if current_context():
                     err = (
                         "%s yielded with context %s rather than sentinel,"
                         " yielded on line %d in %s"
                         % (
                             frame.f_code.co_name,
-                            LoggingContext.current_context(),
+                            current_context(),
                             frame.f_lineno,
                             frame.f_code.co_filename,
                         )
@@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
             except Exception as e:
                 result = Failure(e)
 
-            if LoggingContext.current_context() != expected_context:
+            if current_context() != expected_context:
 
                 # This happens because the context is lost sometime *after* the
                 # previous yield and *after* the current yield. E.g. the
@@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                     % (
                         frame.f_code.co_name,
                         expected_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                         last_yield_line_no,
                         frame.f_lineno,
                         frame.f_code.co_filename,