| diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 860b99a4c6..a8f674d13d 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -51,7 +51,7 @@ try:
 
     is_thread_resource_usage_supported = True
 
-    def get_thread_resource_usage():
+    def get_thread_resource_usage() -> "Optional[resource._RUsage]":
         return resource.getrusage(RUSAGE_THREAD)
 
 
@@ -60,7 +60,7 @@ except Exception:
     # won't track resource usage.
     is_thread_resource_usage_supported = False
 
-    def get_thread_resource_usage():
+    def get_thread_resource_usage() -> "Optional[resource._RUsage]":
         return None
 
 
@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
         return res
 
 
-LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+    """Sentinel to represent the root context"""
+
+    __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+    def __init__(self) -> None:
+        # Minimal set for compatibility with LoggingContext
+        self.previous_context = None
+        self.finished = False
+        self.request = None
+        self.scope = None
+        self.tag = None
+
+    def __str__(self):
+        return "sentinel"
+
+    def copy_to(self, record):
+        pass
+
+    def copy_to_twisted_log_entry(self, record):
+        record["request"] = None
+        record["scope"] = None
+
+    def start(self, rusage: "Optional[resource._RUsage]"):
+        pass
+
+    def stop(self, rusage: "Optional[resource._RUsage]"):
+        pass
+
+    def add_database_transaction(self, duration_sec):
+        pass
+
+    def add_database_scheduled(self, sched_sec):
+        pass
+
+    def record_event_fetch(self, event_count):
+        pass
+
+    def __nonzero__(self):
+        return False
+
+    __bool__ = __nonzero__  # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
 
 
 class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
         "_resource_usage",
         "usage_start",
         "main_thread",
-        "alive",
+        "finished",
         "request",
         "tag",
         "scope",
     ]
 
-    thread_local = threading.local()
-
-    class Sentinel(object):
-        """Sentinel to represent the root context"""
-
-        __slots__ = ["previous_context", "alive", "request", "scope", "tag"]
-
-        def __init__(self) -> None:
-            # Minimal set for compatibility with LoggingContext
-            self.previous_context = None
-            self.alive = None
-            self.request = None
-            self.scope = None
-            self.tag = None
-
-        def __str__(self):
-            return "sentinel"
-
-        def copy_to(self, record):
-            pass
-
-        def copy_to_twisted_log_entry(self, record):
-            record["request"] = None
-            record["scope"] = None
-
-        def start(self):
-            pass
-
-        def stop(self):
-            pass
-
-        def add_database_transaction(self, duration_sec):
-            pass
-
-        def add_database_scheduled(self, sched_sec):
-            pass
-
-        def record_event_fetch(self, event_count):
-            pass
-
-        def __nonzero__(self):
-            return False
-
-        __bool__ = __nonzero__  # python3
-
-    sentinel = Sentinel()
-
     def __init__(self, name=None, parent_context=None, request=None) -> None:
-        self.previous_context = LoggingContext.current_context()
+        self.previous_context = current_context()
         self.name = name
 
         # track the resources used by this context so far
         self._resource_usage = ContextResourceUsage()
 
-        # If alive has the thread resource usage when the logcontext last
-        # became active.
-        self.usage_start = None
+        # The thread resource usage when the logcontext became active. None
+        # if the context is not currently active.
+        self.usage_start = None  # type: Optional[resource._RUsage]
 
         self.main_thread = get_thread_id()
         self.request = None
         self.tag = ""
-        self.alive = True
         self.scope = None  # type: Optional[_LogContextScope]
 
+        # keep track of whether we have hit the __exit__ block for this context
+        # (suggesting that the the thing that created the context thinks it should
+        # be finished, and that re-activating it would suggest an error).
+        self.finished = False
+
         self.parent_context = parent_context
 
         if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
             return str(self.request)
         return "%s@%x" % (self.name, id(self))
 
-    @classmethod
-    def current_context(cls) -> LoggingContextOrSentinel:
-        """Get the current logging context from thread local storage
-
-        Returns:
-            LoggingContext: the current logging context
-        """
-        return getattr(cls.thread_local, "current_context", cls.sentinel)
-
-    @classmethod
-    def set_current_context(
-        cls, context: LoggingContextOrSentinel
-    ) -> LoggingContextOrSentinel:
-        """Set the current logging context in thread local storage
-        Args:
-            context(LoggingContext): The context to activate.
-        Returns:
-            The context that was previously active
-        """
-        current = cls.current_context()
-
-        if current is not context:
-            current.stop()
-            cls.thread_local.current_context = context
-            context.start()
-        return current
-
     def __enter__(self) -> "LoggingContext":
         """Enters this logging context into thread local storage"""
-        old_context = self.set_current_context(self)
+        old_context = set_current_context(self)
         if self.previous_context != old_context:
             logger.warning(
                 "Expected previous context %r, found %r",
                 self.previous_context,
                 old_context,
             )
-        self.alive = True
-
         return self
 
     def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
         Returns:
             None to avoid suppressing any exceptions that were thrown.
         """
-        current = self.set_current_context(self.previous_context)
+        current = set_current_context(self.previous_context)
         if current is not self:
-            if current is self.sentinel:
+            if current is SENTINEL_CONTEXT:
                 logger.warning("Expected logging context %s was lost", self)
             else:
                 logger.warning(
                     "Expected logging context %s but found %s", self, current
                 )
-        self.alive = False
 
-        # if we have a parent, pass our CPU usage stats on
-        if self.parent_context is not None and hasattr(
-            self.parent_context, "_resource_usage"
-        ):
-            self.parent_context._resource_usage += self._resource_usage
-
-            # reset them in case we get entered again
-            self._resource_usage.reset()
+        # the fact that we are here suggests that the caller thinks that everything
+        # is done and dusted for this logcontext, and further activity will not get
+        # recorded against the correct metrics.
+        self.finished = True
 
     def copy_to(self, record) -> None:
         """Copy logging fields from this context to a log record or
@@ -366,35 +336,71 @@ class LoggingContext(object):
         record["request"] = self.request
         record["scope"] = self.scope
 
-    def start(self) -> None:
+    def start(self, rusage: "Optional[resource._RUsage]") -> None:
+        """
+        Record that this logcontext is currently running.
+
+        This should not be called directly: use set_current_context
+
+        Args:
+            rusage: the resources used by the current thread, at the point of
+                switching to this logcontext. May be None if this platform doesn't
+                support getrusuage.
+        """
         if get_thread_id() != self.main_thread:
             logger.warning("Started logcontext %s on different thread", self)
             return
 
+        if self.finished:
+            logger.warning("Re-starting finished log context %s", self)
+
         # If we haven't already started record the thread resource usage so
         # far
-        if not self.usage_start:
-            self.usage_start = get_thread_resource_usage()
+        if self.usage_start:
+            logger.warning("Re-starting already-active log context %s", self)
+        else:
+            self.usage_start = rusage
 
-    def stop(self) -> None:
-        if get_thread_id() != self.main_thread:
-            logger.warning("Stopped logcontext %s on different thread", self)
-            return
+    def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+        """
+        Record that this logcontext is no longer running.
+
+        This should not be called directly: use set_current_context
+
+        Args:
+            rusage: the resources used by the current thread, at the point of
+                switching away from this logcontext. May be None if this platform
+                doesn't support getrusuage.
+        """
+
+        try:
+            if get_thread_id() != self.main_thread:
+                logger.warning("Stopped logcontext %s on different thread", self)
+                return
+
+            if not rusage:
+                return
 
-        # When we stop, let's record the cpu used since we started
-        if not self.usage_start:
-            # Log a warning on platforms that support thread usage tracking
-            if is_thread_resource_usage_supported:
+            # Record the cpu used since we started
+            if not self.usage_start:
                 logger.warning(
-                    "Called stop on logcontext %s without calling start", self
+                    "Called stop on logcontext %s without recording a start rusage",
+                    self,
                 )
-            return
+                return
+
+            utime_delta, stime_delta = self._get_cputime(rusage)
+            self._resource_usage.ru_utime += utime_delta
+            self._resource_usage.ru_stime += stime_delta
 
-        utime_delta, stime_delta = self._get_cputime()
-        self._resource_usage.ru_utime += utime_delta
-        self._resource_usage.ru_stime += stime_delta
+            # if we have a parent, pass our CPU usage stats on
+            if self.parent_context:
+                self.parent_context._resource_usage += self._resource_usage
 
-        self.usage_start = None
+                # reset them in case we get entered again
+                self._resource_usage.reset()
+        finally:
+            self.usage_start = None
 
     def get_resource_usage(self) -> ContextResourceUsage:
         """Get resources used by this logcontext so far.
@@ -409,25 +415,25 @@ class LoggingContext(object):
         # If we are on the correct thread and we're currently running then we
         # can include resource usage so far.
         is_main_thread = get_thread_id() == self.main_thread
-        if self.alive and self.usage_start and is_main_thread:
-            utime_delta, stime_delta = self._get_cputime()
+        if self.usage_start and is_main_thread:
+            rusage = get_thread_resource_usage()
+            assert rusage is not None
+            utime_delta, stime_delta = self._get_cputime(rusage)
             res.ru_utime += utime_delta
             res.ru_stime += stime_delta
 
         return res
 
-    def _get_cputime(self) -> Tuple[float, float]:
-        """Get the cpu usage time so far
+    def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+        """Get the cpu usage time between start() and the given rusage
+
+        Args:
+            rusage: the current resource usage
 
         Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
         """
         assert self.usage_start is not None
 
-        current = get_thread_resource_usage()
-
-        # Indicate to mypy that we know that self.usage_start is None.
-        assert self.usage_start is not None
-
         utime_delta = current.ru_utime - self.usage_start.ru_utime
         stime_delta = current.ru_stime - self.usage_start.ru_stime
 
@@ -492,7 +498,7 @@ class LoggingContextFilter(logging.Filter):
         Returns:
             True to include the record in the log output.
         """
-        context = LoggingContext.current_context()
+        context = current_context()
         for key, value in self.defaults.items():
             setattr(record, key, value)
 
@@ -512,27 +518,24 @@ class PreserveLoggingContext(object):
 
     __slots__ = ["current_context", "new_context", "has_parent"]
 
-    def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
-        if new_context is None:
-            self.new_context = LoggingContext.sentinel  # type: LoggingContextOrSentinel
-        else:
-            self.new_context = new_context
+    def __init__(
+        self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+    ) -> None:
+        self.new_context = new_context
 
     def __enter__(self) -> None:
         """Captures the current logging context"""
-        self.current_context = LoggingContext.set_current_context(self.new_context)
+        self.current_context = set_current_context(self.new_context)
 
         if self.current_context:
             self.has_parent = self.current_context.previous_context is not None
-            if not self.current_context.alive:
-                logger.debug("Entering dead context: %s", self.current_context)
 
     def __exit__(self, type, value, traceback) -> None:
         """Restores the current logging context"""
-        context = LoggingContext.set_current_context(self.current_context)
+        context = set_current_context(self.current_context)
 
         if context != self.new_context:
-            if context is LoggingContext.sentinel:
+            if not context:
                 logger.warning("Expected logging context %s was lost", self.new_context)
             else:
                 logger.warning(
@@ -541,9 +544,37 @@ class PreserveLoggingContext(object):
                     context,
                 )
 
-        if self.current_context is not LoggingContext.sentinel:
-            if not self.current_context.alive:
-                logger.debug("Restoring dead context: %s", self.current_context)
+
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+    """Get the current logging context from thread local storage"""
+    return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
+
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+    """Set the current logging context in thread local storage
+    Args:
+        context(LoggingContext): The context to activate.
+    Returns:
+        The context that was previously active
+    """
+    # everything blows up if we allow current_context to be set to None, so sanity-check
+    # that now.
+    if context is None:
+        raise TypeError("'context' argument may not be None")
+
+    current = current_context()
+
+    if current is not context:
+        rusage = get_thread_resource_usage()
+        current.stop(rusage)
+        _thread_local.current_context = context
+        context.start(rusage)
+
+    return current
 
 
 def nested_logging_context(
@@ -572,7 +603,7 @@ def nested_logging_context(
     if parent_context is not None:
         context = parent_context  # type: LoggingContextOrSentinel
     else:
-        context = LoggingContext.current_context()
+        context = current_context()
     return LoggingContext(
         parent_context=context, request=str(context.request) + "-" + suffix
     )
@@ -604,7 +635,7 @@ def run_in_background(f, *args, **kwargs):
     CRITICAL error about an unhandled error will be logged without much
     indication about where it came from.
     """
-    current = LoggingContext.current_context()
+    current = current_context()
     try:
         res = f(*args, **kwargs)
     except:  # noqa: E722
@@ -625,7 +656,7 @@ def run_in_background(f, *args, **kwargs):
 
     # The function may have reset the context before returning, so
     # we need to restore it now.
-    ctx = LoggingContext.set_current_context(current)
+    ctx = set_current_context(current)
 
     # The original context will be restored when the deferred
     # completes, but there is nothing waiting for it, so it will
@@ -674,7 +705,7 @@ def make_deferred_yieldable(deferred):
 
     # ok, we can't be sure that a yield won't block, so let's reset the
     # logcontext, and add a callback to the deferred to restore it.
-    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    prev_context = set_current_context(SENTINEL_CONTEXT)
     deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
@@ -684,7 +715,7 @@ ResultT = TypeVar("ResultT")
 
 def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     """A callback function which just sets the logging context"""
-    LoggingContext.set_current_context(context)
+    set_current_context(context)
     return result
 
 
@@ -752,7 +783,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = LoggingContext.current_context()
+    logcontext = current_context()
 
     def g():
         with LoggingContext(parent_context=logcontext):
 |