summary refs log tree commit diff
path: root/synapse/logging
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging')
-rw-r--r--synapse/logging/_structured.py20
-rw-r--r--synapse/logging/_terse_json.py109
-rw-r--r--synapse/logging/context.py437
-rw-r--r--synapse/logging/opentracing.py44
-rw-r--r--synapse/logging/scopecontextmanager.py13
-rw-r--r--synapse/logging/utils.py30
6 files changed, 425 insertions, 228 deletions
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 3220e985a9..7372450b45 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
     TerseJSONToConsoleLogObserver,
     TerseJSONToTCPLogObserver,
 )
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 
 
 def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
             ].startswith("Timing out client"):
                 return
 
-        context = LoggingContext.current_context()
+        context = current_context()
 
         # Copy the context information to the log event.
         if context is not None:
@@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
 
 
 def parse_drain_configs(
-    drains: dict
+    drains: dict,
 ) -> typing.Generator[DrainConfiguration, None, None]:
     """
     Parse the drain configurations.
@@ -261,6 +261,18 @@ def parse_drain_configs(
             )
 
 
+class StoppableLogPublisher(LogPublisher):
+    """
+    A log publisher that can tell its observers to shut down any external
+    communications.
+    """
+
+    def stop(self):
+        for obs in self._observers:
+            if hasattr(obs, "stop"):
+                obs.stop()
+
+
 def setup_structured_logging(
     hs,
     config,
@@ -336,7 +348,7 @@ def setup_structured_logging(
             # We should never get here, but, just in case, throw an error.
             raise ConfigError("%s drain type cannot be configured" % (observer.type,))
 
-    publisher = LogPublisher(*observers)
+    publisher = StoppableLogPublisher(*observers)
     log_filter = LogLevelFilterPredicate()
 
     for namespace, namespace_config in log_config.get(
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 0ebbde06f2..c0b9384189 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -17,25 +17,29 @@
 Log formatters that output terse JSON.
 """
 
+import json
 import sys
+import traceback
 from collections import deque
 from ipaddress import IPv4Address, IPv6Address, ip_address
 from math import floor
-from typing import IO
+from typing import IO, Optional
 
 import attr
-from simplejson import dumps
 from zope.interface import implementer
 
 from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
 from twisted.internet.endpoints import (
     HostnameEndpoint,
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
+from twisted.internet.interfaces import IPushProducer, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.logger import FileLogObserver, ILogObserver, Logger
-from twisted.python.failure import Failure
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
 
 
 def flatten_event(event: dict, metadata: dict, include_time: bool = False):
@@ -141,19 +145,57 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
 
     def formatEvent(_event: dict) -> str:
         flattened = flatten_event(_event, metadata)
-        return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+        return _encoder.encode(flattened) + "\n"
 
     return FileLogObserver(outFile, formatEvent)
 
 
 @attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+    """
+    An IPushProducer that writes logs from its buffer to its transport when it
+    is resumed.
+
+    Args:
+        buffer: Log buffer to read logs from.
+        transport: Transport to write to.
+    """
+
+    transport = attr.ib(type=ITransport)
+    _buffer = attr.ib(type=deque)
+    _paused = attr.ib(default=False, type=bool, init=False)
+
+    def pauseProducing(self):
+        self._paused = True
+
+    def stopProducing(self):
+        self._paused = True
+        self._buffer = deque()
+
+    def resumeProducing(self):
+        self._paused = False
+
+        while self._paused is False and (self._buffer and self.transport.connected):
+            try:
+                event = self._buffer.popleft()
+                self.transport.write(_encoder.encode(event).encode("utf8"))
+                self.transport.write(b"\n")
+            except Exception:
+                # Something has gone wrong writing to the transport -- log it
+                # and break out of the while.
+                traceback.print_exc(file=sys.__stderr__)
+                break
+
+
+@attr.s
 @implementer(ILogObserver)
 class TerseJSONToTCPLogObserver(object):
     """
     An IObserver that writes JSON logs to a TCP target.
 
     Args:
-        hs (HomeServer): The Homeserver that is being logged for.
+        hs (HomeServer): The homeserver that is being logged for.
         host: The host of the logging target.
         port: The logging target's port.
         metadata: Metadata to be added to each log entry.
@@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object):
     metadata = attr.ib(type=dict)
     maximum_buffer = attr.ib(type=int)
     _buffer = attr.ib(default=attr.Factory(deque), type=deque)
-    _writer = attr.ib(default=None)
+    _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
     _logger = attr.ib(default=attr.Factory(Logger))
+    _producer = attr.ib(default=None, type=Optional[LogProducer])
 
     def start(self) -> None:
 
@@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object):
         factory = Factory.forProtocol(Protocol)
         self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
         self._service.startService()
+        self._connect()
 
-    def _write_loop(self) -> None:
+    def stop(self):
+        self._service.stopService()
+
+    def _connect(self) -> None:
         """
-        Implement the write loop.
+        Triggers an attempt to connect then write to the remote if not already writing.
         """
-        if self._writer:
+        if self._connection_waiter:
             return
 
-        self._writer = self._service.whenConnected()
+        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+        @self._connection_waiter.addErrback
+        def fail(r):
+            r.printTraceback(file=sys.__stderr__)
+            self._connection_waiter = None
+            self._connect()
 
-        @self._writer.addBoth
+        @self._connection_waiter.addCallback
         def writer(r):
-            if isinstance(r, Failure):
-                r.printTraceback(file=sys.__stderr__)
-                self._writer = None
-                self.hs.get_reactor().callLater(1, self._write_loop)
+            # We have a connection. If we already have a producer, and its
+            # transport is the same, just trigger a resumeProducing.
+            if self._producer and r.transport is self._producer.transport:
+                self._producer.resumeProducing()
+                self._connection_waiter = None
                 return
 
-            try:
-                for event in self._buffer:
-                    r.transport.write(
-                        dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
-                            "utf8"
-                        )
-                    )
-                    r.transport.write(b"\n")
-                self._buffer.clear()
-            except Exception as e:
-                sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
-
-            self._writer = False
-            self.hs.get_reactor().callLater(1, self._write_loop)
+            # If the producer is still producing, stop it.
+            if self._producer:
+                self._producer.stopProducing()
+
+            # Make a new producer and start it.
+            self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+            r.transport.registerProducer(self._producer, True)
+            self._producer.resumeProducing()
+            self._connection_waiter = None
 
     def _handle_pressure(self) -> None:
         """
@@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object):
             self._logger.failure("Failed clearing backpressure")
 
         # Try and write immediately.
-        self._write_loop()
+        self._connect()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 370000e377..8b9c4e38bd 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,13 +23,20 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
 
+import inspect
 import logging
 import threading
 import types
-from typing import Any, List
+import warnings
+from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import Literal
 
 from twisted.internet import defer, threads
 
+if TYPE_CHECKING:
+    from synapse.logging.scopecontextmanager import _LogContextScope
+
 logger = logging.getLogger(__name__)
 
 try:
@@ -45,7 +52,7 @@ try:
 
     is_thread_resource_usage_supported = True
 
-    def get_thread_resource_usage():
+    def get_thread_resource_usage() -> "Optional[resource._RUsage]":
         return resource.getrusage(RUSAGE_THREAD)
 
 
@@ -54,7 +61,7 @@ except Exception:
     # won't track resource usage.
     is_thread_resource_usage_supported = False
 
-    def get_thread_resource_usage():
+    def get_thread_resource_usage() -> "Optional[resource._RUsage]":
         return None
 
 
@@ -90,7 +97,7 @@ class ContextResourceUsage(object):
         "evt_db_fetch_count",
     ]
 
-    def __init__(self, copy_from=None):
+    def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
         """Create a new ContextResourceUsage
 
         Args:
@@ -100,27 +107,28 @@ class ContextResourceUsage(object):
         if copy_from is None:
             self.reset()
         else:
-            self.ru_utime = copy_from.ru_utime
-            self.ru_stime = copy_from.ru_stime
-            self.db_txn_count = copy_from.db_txn_count
+            # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
+            self.ru_utime = copy_from.ru_utime  # type: float
+            self.ru_stime = copy_from.ru_stime  # type: float
+            self.db_txn_count = copy_from.db_txn_count  # type: int
 
-            self.db_txn_duration_sec = copy_from.db_txn_duration_sec
-            self.db_sched_duration_sec = copy_from.db_sched_duration_sec
-            self.evt_db_fetch_count = copy_from.evt_db_fetch_count
+            self.db_txn_duration_sec = copy_from.db_txn_duration_sec  # type: float
+            self.db_sched_duration_sec = copy_from.db_sched_duration_sec  # type: float
+            self.evt_db_fetch_count = copy_from.evt_db_fetch_count  # type: int
 
-    def copy(self):
+    def copy(self) -> "ContextResourceUsage":
         return ContextResourceUsage(copy_from=self)
 
-    def reset(self):
+    def reset(self) -> None:
         self.ru_stime = 0.0
         self.ru_utime = 0.0
         self.db_txn_count = 0
 
-        self.db_txn_duration_sec = 0
-        self.db_sched_duration_sec = 0
+        self.db_txn_duration_sec = 0.0
+        self.db_sched_duration_sec = 0.0
         self.evt_db_fetch_count = 0
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
             "db_txn_count='%r', db_txn_duration_sec='%r', "
@@ -134,7 +142,7 @@ class ContextResourceUsage(object):
             self.evt_db_fetch_count,
         )
 
-    def __iadd__(self, other):
+    def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
         """Add another ContextResourceUsage's stats to this one's.
 
         Args:
@@ -148,7 +156,7 @@ class ContextResourceUsage(object):
         self.evt_db_fetch_count += other.evt_db_fetch_count
         return self
 
-    def __isub__(self, other):
+    def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
         self.ru_utime -= other.ru_utime
         self.ru_stime -= other.ru_stime
         self.db_txn_count -= other.db_txn_count
@@ -157,17 +165,67 @@ class ContextResourceUsage(object):
         self.evt_db_fetch_count -= other.evt_db_fetch_count
         return self
 
-    def __add__(self, other):
+    def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
         res = ContextResourceUsage(copy_from=self)
         res += other
         return res
 
-    def __sub__(self, other):
+    def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
         res = ContextResourceUsage(copy_from=self)
         res -= other
         return res
 
 
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+    """Sentinel to represent the root context"""
+
+    __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+    def __init__(self) -> None:
+        # Minimal set for compatibility with LoggingContext
+        self.previous_context = None
+        self.finished = False
+        self.request = None
+        self.scope = None
+        self.tag = None
+
+    def __str__(self):
+        return "sentinel"
+
+    def copy_to(self, record):
+        pass
+
+    def copy_to_twisted_log_entry(self, record):
+        record["request"] = None
+        record["scope"] = None
+
+    def start(self, rusage: "Optional[resource._RUsage]"):
+        pass
+
+    def stop(self, rusage: "Optional[resource._RUsage]"):
+        pass
+
+    def add_database_transaction(self, duration_sec):
+        pass
+
+    def add_database_scheduled(self, sched_sec):
+        pass
+
+    def record_event_fetch(self, event_count):
+        pass
+
+    def __nonzero__(self):
+        return False
+
+    __bool__ = __nonzero__  # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
+
+
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
@@ -189,67 +247,32 @@ class LoggingContext(object):
         "_resource_usage",
         "usage_start",
         "main_thread",
-        "alive",
+        "finished",
         "request",
         "tag",
         "scope",
     ]
 
-    thread_local = threading.local()
-
-    class Sentinel(object):
-        """Sentinel to represent the root context"""
-
-        __slots__ = []  # type: List[Any]
-
-        def __str__(self):
-            return "sentinel"
-
-        def copy_to(self, record):
-            pass
-
-        def copy_to_twisted_log_entry(self, record):
-            record["request"] = None
-            record["scope"] = None
-
-        def start(self):
-            pass
-
-        def stop(self):
-            pass
-
-        def add_database_transaction(self, duration_sec):
-            pass
-
-        def add_database_scheduled(self, sched_sec):
-            pass
-
-        def record_event_fetch(self, event_count):
-            pass
-
-        def __nonzero__(self):
-            return False
-
-        __bool__ = __nonzero__  # python3
-
-    sentinel = Sentinel()
-
-    def __init__(self, name=None, parent_context=None, request=None):
-        self.previous_context = LoggingContext.current_context()
+    def __init__(self, name=None, parent_context=None, request=None) -> None:
+        self.previous_context = current_context()
         self.name = name
 
         # track the resources used by this context so far
         self._resource_usage = ContextResourceUsage()
 
-        # If alive has the thread resource usage when the logcontext last
-        # became active.
-        self.usage_start = None
+        # The thread resource usage when the logcontext became active. None
+        # if the context is not currently active.
+        self.usage_start = None  # type: Optional[resource._RUsage]
 
         self.main_thread = get_thread_id()
         self.request = None
         self.tag = ""
-        self.alive = True
-        self.scope = None
+        self.scope = None  # type: Optional[_LogContextScope]
+
+        # keep track of whether we have hit the __exit__ block for this context
+        # (suggesting that the the thing that created the context thinks it should
+        # be finished, and that re-activating it would suggest an error).
+        self.finished = False
 
         self.parent_context = parent_context
 
@@ -260,76 +283,83 @@ class LoggingContext(object):
             # the request param overrides the request from the parent context
             self.request = request
 
-    def __str__(self):
+    def __str__(self) -> str:
         if self.request:
             return str(self.request)
         return "%s@%x" % (self.name, id(self))
 
     @classmethod
-    def current_context(cls):
+    def current_context(cls) -> LoggingContextOrSentinel:
         """Get the current logging context from thread local storage
 
+        This exists for backwards compatibility. ``current_context()`` should be
+        called directly.
+
         Returns:
             LoggingContext: the current logging context
         """
-        return getattr(cls.thread_local, "current_context", cls.sentinel)
+        warnings.warn(
+            "synapse.logging.context.LoggingContext.current_context() is deprecated "
+            "in favor of synapse.logging.context.current_context().",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return current_context()
 
     @classmethod
-    def set_current_context(cls, context):
+    def set_current_context(
+        cls, context: LoggingContextOrSentinel
+    ) -> LoggingContextOrSentinel:
         """Set the current logging context in thread local storage
+
+        This exists for backwards compatibility. ``set_current_context()`` should be
+        called directly.
+
         Args:
             context(LoggingContext): The context to activate.
         Returns:
             The context that was previously active
         """
-        current = cls.current_context()
-
-        if current is not context:
-            current.stop()
-            cls.thread_local.current_context = context
-            context.start()
-        return current
+        warnings.warn(
+            "synapse.logging.context.LoggingContext.set_current_context() is deprecated "
+            "in favor of synapse.logging.context.set_current_context().",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return set_current_context(context)
 
-    def __enter__(self):
+    def __enter__(self) -> "LoggingContext":
         """Enters this logging context into thread local storage"""
-        old_context = self.set_current_context(self)
+        old_context = set_current_context(self)
         if self.previous_context != old_context:
-            logger.warn(
+            logger.warning(
                 "Expected previous context %r, found %r",
                 self.previous_context,
                 old_context,
             )
-        self.alive = True
-
         return self
 
-    def __exit__(self, type, value, traceback):
+    def __exit__(self, type, value, traceback) -> None:
         """Restore the logging context in thread local storage to the state it
         was before this context was entered.
         Returns:
             None to avoid suppressing any exceptions that were thrown.
         """
-        current = self.set_current_context(self.previous_context)
+        current = set_current_context(self.previous_context)
         if current is not self:
-            if current is self.sentinel:
+            if current is SENTINEL_CONTEXT:
                 logger.warning("Expected logging context %s was lost", self)
             else:
                 logger.warning(
                     "Expected logging context %s but found %s", self, current
                 )
-        self.previous_context = None
-        self.alive = False
-
-        # if we have a parent, pass our CPU usage stats on
-        if self.parent_context is not None and hasattr(
-            self.parent_context, "_resource_usage"
-        ):
-            self.parent_context._resource_usage += self._resource_usage
 
-            # reset them in case we get entered again
-            self._resource_usage.reset()
+        # the fact that we are here suggests that the caller thinks that everything
+        # is done and dusted for this logcontext, and further activity will not get
+        # recorded against the correct metrics.
+        self.finished = True
 
-    def copy_to(self, record):
+    def copy_to(self, record) -> None:
         """Copy logging fields from this context to a log record or
         another LoggingContext
         """
@@ -340,44 +370,72 @@ class LoggingContext(object):
         # we also track the current scope:
         record.scope = self.scope
 
-    def copy_to_twisted_log_entry(self, record):
+    def copy_to_twisted_log_entry(self, record) -> None:
         """
         Copy logging fields from this context to a Twisted log record.
         """
         record["request"] = self.request
         record["scope"] = self.scope
 
-    def start(self):
+    def start(self, rusage: "Optional[resource._RUsage]") -> None:
+        """
+        Record that this logcontext is currently running.
+
+        This should not be called directly: use set_current_context
+
+        Args:
+            rusage: the resources used by the current thread, at the point of
+                switching to this logcontext. May be None if this platform doesn't
+                support getrusuage.
+        """
         if get_thread_id() != self.main_thread:
             logger.warning("Started logcontext %s on different thread", self)
             return
 
+        if self.finished:
+            logger.warning("Re-starting finished log context %s", self)
+
         # If we haven't already started record the thread resource usage so
         # far
-        if not self.usage_start:
-            self.usage_start = get_thread_resource_usage()
+        if self.usage_start:
+            logger.warning("Re-starting already-active log context %s", self)
+        else:
+            self.usage_start = rusage
 
-    def stop(self):
-        if get_thread_id() != self.main_thread:
-            logger.warning("Stopped logcontext %s on different thread", self)
-            return
+    def stop(self, rusage: "Optional[resource._RUsage]") -> None:
+        """
+        Record that this logcontext is no longer running.
 
-        # When we stop, let's record the cpu used since we started
-        if not self.usage_start:
-            # Log a warning on platforms that support thread usage tracking
-            if is_thread_resource_usage_supported:
+        This should not be called directly: use set_current_context
+
+        Args:
+            rusage: the resources used by the current thread, at the point of
+                switching away from this logcontext. May be None if this platform
+                doesn't support getrusuage.
+        """
+
+        try:
+            if get_thread_id() != self.main_thread:
+                logger.warning("Stopped logcontext %s on different thread", self)
+                return
+
+            if not rusage:
+                return
+
+            # Record the cpu used since we started
+            if not self.usage_start:
                 logger.warning(
-                    "Called stop on logcontext %s without calling start", self
+                    "Called stop on logcontext %s without recording a start rusage",
+                    self,
                 )
-            return
+                return
 
-        utime_delta, stime_delta = self._get_cputime()
-        self._resource_usage.ru_utime += utime_delta
-        self._resource_usage.ru_stime += stime_delta
-
-        self.usage_start = None
+            utime_delta, stime_delta = self._get_cputime(rusage)
+            self.add_cputime(utime_delta, stime_delta)
+        finally:
+            self.usage_start = None
 
-    def get_resource_usage(self):
+    def get_resource_usage(self) -> ContextResourceUsage:
         """Get resources used by this logcontext so far.
 
         Returns:
@@ -390,19 +448,24 @@ class LoggingContext(object):
         # If we are on the correct thread and we're currently running then we
         # can include resource usage so far.
         is_main_thread = get_thread_id() == self.main_thread
-        if self.alive and self.usage_start and is_main_thread:
-            utime_delta, stime_delta = self._get_cputime()
+        if self.usage_start and is_main_thread:
+            rusage = get_thread_resource_usage()
+            assert rusage is not None
+            utime_delta, stime_delta = self._get_cputime(rusage)
             res.ru_utime += utime_delta
             res.ru_stime += stime_delta
 
         return res
 
-    def _get_cputime(self):
-        """Get the cpu usage time so far
+    def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]:
+        """Get the cpu usage time between start() and the given rusage
+
+        Args:
+            rusage: the current resource usage
 
         Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
         """
-        current = get_thread_resource_usage()
+        assert self.usage_start is not None
 
         utime_delta = current.ru_utime - self.usage_start.ru_utime
         stime_delta = current.ru_stime - self.usage_start.ru_stime
@@ -426,30 +489,52 @@ class LoggingContext(object):
 
         return utime_delta, stime_delta
 
-    def add_database_transaction(self, duration_sec):
+    def add_cputime(self, utime_delta: float, stime_delta: float) -> None:
+        """Update the CPU time usage of this context (and any parents, recursively).
+
+        Args:
+            utime_delta: additional user time, in seconds, spent in this context.
+            stime_delta: additional system time, in seconds, spent in this context.
+        """
+        self._resource_usage.ru_utime += utime_delta
+        self._resource_usage.ru_stime += stime_delta
+        if self.parent_context:
+            self.parent_context.add_cputime(utime_delta, stime_delta)
+
+    def add_database_transaction(self, duration_sec: float) -> None:
+        """Record the use of a database transaction and the length of time it took.
+
+        Args:
+            duration_sec: The number of seconds the database transaction took.
+        """
         if duration_sec < 0:
             raise ValueError("DB txn time can only be non-negative")
         self._resource_usage.db_txn_count += 1
         self._resource_usage.db_txn_duration_sec += duration_sec
+        if self.parent_context:
+            self.parent_context.add_database_transaction(duration_sec)
 
-    def add_database_scheduled(self, sched_sec):
+    def add_database_scheduled(self, sched_sec: float) -> None:
         """Record a use of the database pool
 
         Args:
-            sched_sec (float): number of seconds it took us to get a
-                connection
+            sched_sec: number of seconds it took us to get a connection
         """
         if sched_sec < 0:
             raise ValueError("DB scheduling time can only be non-negative")
         self._resource_usage.db_sched_duration_sec += sched_sec
+        if self.parent_context:
+            self.parent_context.add_database_scheduled(sched_sec)
 
-    def record_event_fetch(self, event_count):
+    def record_event_fetch(self, event_count: int) -> None:
         """Record a number of events being fetched from the db
 
         Args:
-            event_count (int): number of events being fetched
+            event_count: number of events being fetched
         """
         self._resource_usage.evt_db_fetch_count += event_count
+        if self.parent_context:
+            self.parent_context.record_event_fetch(event_count)
 
 
 class LoggingContextFilter(logging.Filter):
@@ -460,15 +545,15 @@ class LoggingContextFilter(logging.Filter):
             missing fields
     """
 
-    def __init__(self, **defaults):
+    def __init__(self, **defaults) -> None:
         self.defaults = defaults
 
-    def filter(self, record):
+    def filter(self, record) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
-        context = LoggingContext.current_context()
+        context = current_context()
         for key, value in self.defaults.items():
             setattr(record, key, value)
 
@@ -488,26 +573,24 @@ class PreserveLoggingContext(object):
 
     __slots__ = ["current_context", "new_context", "has_parent"]
 
-    def __init__(self, new_context=None):
-        if new_context is None:
-            new_context = LoggingContext.sentinel
+    def __init__(
+        self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+    ) -> None:
         self.new_context = new_context
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         """Captures the current logging context"""
-        self.current_context = LoggingContext.set_current_context(self.new_context)
+        self.current_context = set_current_context(self.new_context)
 
         if self.current_context:
             self.has_parent = self.current_context.previous_context is not None
-            if not self.current_context.alive:
-                logger.debug("Entering dead context: %s", self.current_context)
 
-    def __exit__(self, type, value, traceback):
+    def __exit__(self, type, value, traceback) -> None:
         """Restores the current logging context"""
-        context = LoggingContext.set_current_context(self.current_context)
+        context = set_current_context(self.current_context)
 
         if context != self.new_context:
-            if context is LoggingContext.sentinel:
+            if not context:
                 logger.warning("Expected logging context %s was lost", self.new_context)
             else:
                 logger.warning(
@@ -516,12 +599,42 @@ class PreserveLoggingContext(object):
                     context,
                 )
 
-        if self.current_context is not LoggingContext.sentinel:
-            if not self.current_context.alive:
-                logger.debug("Restoring dead context: %s", self.current_context)
 
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+    """Get the current logging context from thread local storage"""
+    return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
 
-def nested_logging_context(suffix, parent_context=None):
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+    """Set the current logging context in thread local storage
+    Args:
+        context(LoggingContext): The context to activate.
+    Returns:
+        The context that was previously active
+    """
+    # everything blows up if we allow current_context to be set to None, so sanity-check
+    # that now.
+    if context is None:
+        raise TypeError("'context' argument may not be None")
+
+    current = current_context()
+
+    if current is not context:
+        rusage = get_thread_resource_usage()
+        current.stop(rusage)
+        _thread_local.current_context = context
+        context.start(rusage)
+
+    return current
+
+
+def nested_logging_context(
+    suffix: str, parent_context: Optional[LoggingContext] = None
+) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -542,10 +655,12 @@ def nested_logging_context(suffix, parent_context=None):
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is None:
-        parent_context = LoggingContext.current_context()
+    if parent_context is not None:
+        context = parent_context  # type: LoggingContextOrSentinel
+    else:
+        context = current_context()
     return LoggingContext(
-        parent_context=parent_context, request=parent_context.request + "-" + suffix
+        parent_context=context, request=str(context.request) + "-" + suffix
     )
 
 
@@ -567,12 +682,15 @@ def run_in_background(f, *args, **kwargs):
     yield or await on (for instance because you want to pass it to
     deferred.gatherResults()).
 
+    If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
+    the side effect of executing the coroutine).
+
     Note that if you completely discard the result, you should make sure that
     `f` doesn't raise any deferred exceptions, otherwise a scary-looking
     CRITICAL error about an unhandled error will be logged without much
     indication about where it came from.
     """
-    current = LoggingContext.current_context()
+    current = current_context()
     try:
         res = f(*args, **kwargs)
     except:  # noqa: E722
@@ -593,7 +711,7 @@ def run_in_background(f, *args, **kwargs):
 
     # The function may have reset the context before returning, so
     # we need to restore it now.
-    ctx = LoggingContext.set_current_context(current)
+    ctx = set_current_context(current)
 
     # The original context will be restored when the deferred
     # completes, but there is nothing waiting for it, so it will
@@ -612,7 +730,8 @@ def run_in_background(f, *args, **kwargs):
 
 
 def make_deferred_yieldable(deferred):
-    """Given a deferred, make it follow the Synapse logcontext rules:
+    """Given a deferred (or coroutine), make it follow the Synapse logcontext
+    rules:
 
     If the deferred has completed (or is not actually a Deferred), essentially
     does nothing (just returns another completed deferred with the
@@ -624,6 +743,13 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
+    if inspect.isawaitable(deferred):
+        # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes, it it does then we
+        # don't need to fiddle with log contexts at all and can return
+        # immediately.
+        deferred = defer.ensureDeferred(deferred)
+
     if not isinstance(deferred, defer.Deferred):
         return deferred
 
@@ -634,14 +760,17 @@ def make_deferred_yieldable(deferred):
 
     # ok, we can't be sure that a yield won't block, so let's reset the
     # logcontext, and add a callback to the deferred to restore it.
-    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    prev_context = set_current_context(SENTINEL_CONTEXT)
     deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
 
-def _set_context_cb(result, context):
+ResultT = TypeVar("ResultT")
+
+
+def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     """A callback function which just sets the logging context"""
-    LoggingContext.set_current_context(context)
+    set_current_context(context)
     return result
 
 
@@ -709,7 +838,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = LoggingContext.current_context()
+    logcontext = current_context()
 
     def g():
         with LoggingContext(parent_context=logcontext):
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 308a27213b..5dddf57008 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
 import inspect
 import logging
 import re
+import types
 from functools import wraps
+from typing import TYPE_CHECKING, Dict
 
 from canonicaljson import json
 
@@ -177,6 +179,9 @@ from twisted.internet import defer
 
 from synapse.config import ConfigError
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Helper class
 
 
@@ -295,14 +300,11 @@ def _noop_context_manager(*args, **kwargs):
 # Setup
 
 
-def init_tracer(config):
+def init_tracer(hs: "HomeServer"):
     """Set the whitelists and initialise the JaegerClient tracer
-
-    Args:
-        config (HomeserverConfig): The config used by the homeserver
     """
     global opentracing
-    if not config.opentracer_enabled:
+    if not hs.config.opentracer_enabled:
         # We don't have a tracer
         opentracing = None
         return
@@ -313,18 +315,15 @@ def init_tracer(config):
             "installed."
         )
 
-    # Include the worker name
-    name = config.worker_name if config.worker_name else "master"
-
     # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
     # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
 
-    set_homeserver_whitelist(config.opentracer_whitelist)
+    set_homeserver_whitelist(hs.config.opentracer_whitelist)
 
     JaegerConfig(
-        config=config.jaeger_config,
-        service_name="{} {}".format(config.server_name, name),
-        scope_manager=LogContextScopeManager(config),
+        config=hs.config.jaeger_config,
+        service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+        scope_manager=LogContextScopeManager(hs.config),
     ).initialize_tracer()
 
 
@@ -547,7 +546,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
         return
 
     span = opentracing.tracer.active_span
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -584,7 +583,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
 
     span = opentracing.tracer.active_span
 
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -639,7 +638,7 @@ def get_active_span_text_map(destination=None):
     if destination and not whitelisted_homeserver(destination):
         return {}
 
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(
         opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
     )
@@ -653,7 +652,7 @@ def active_span_context_as_string():
     Returns:
         The active span context encoded as a string.
     """
-    carrier = {}
+    carrier = {}  # type: Dict[str, str]
     if opentracing:
         opentracing.tracer.inject(
             opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +776,7 @@ def trace_servlet(servlet_name, extract_context=False):
             return func
 
         @wraps(func)
-        @defer.inlineCallbacks
-        def _trace_servlet_inner(request, *args, **kwargs):
+        async def _trace_servlet_inner(request, *args, **kwargs):
             request_tags = {
                 "request_id": request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +793,14 @@ def trace_servlet(servlet_name, extract_context=False):
                 scope = start_active_span(servlet_name, tags=request_tags)
 
             with scope:
-                result = yield defer.maybeDeferred(func, request, *args, **kwargs)
-                return result
+                result = func(request, *args, **kwargs)
+
+                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+                    # Some servlets aren't async and just return results
+                    # directly, so we handle that here.
+                    return result
+
+                return await result
 
         return _trace_servlet_inner
 
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 4eed4f2338..dc3ab00cbb 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
 
 import twisted
 
-from synapse.logging.context import LoggingContext, nested_logging_context
+from synapse.logging.context import current_context, nested_logging_context
 
 logger = logging.getLogger(__name__)
 
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
             (Scope) : the Scope that is active, or None if not
             available.
         """
-        ctx = LoggingContext.current_context()
-        if ctx is LoggingContext.sentinel:
-            return None
-        else:
-            return ctx.scope
+        ctx = current_context()
+        return ctx.scope
 
     def activate(self, span, finish_on_close):
         """
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
         """
 
         enter_logcontext = False
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
-        if ctx is LoggingContext.sentinel:
+        if not ctx:
             # We don't want this scope to affect.
             logger.error("Tried to activate scope outside of loggingcontext")
             return Scope(None, span)
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 7df0fa6087..99049bb5d8 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -20,8 +20,6 @@ import time
 from functools import wraps
 from inspect import getcallargs
 
-from six import PY3
-
 _TIME_FUNC_ID = 0
 
 
@@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
     logger = logging.getLogger(name)
 
     if logger.isEnabledFor(logging.DEBUG):
-        if PY3:
-            lineno = f.__code__.co_firstlineno
-            pathname = f.__code__.co_filename
-        else:
-            lineno = f.func_code.co_firstlineno
-            pathname = f.func_code.co_filename
+        lineno = f.__code__.co_firstlineno
+        pathname = f.__code__.co_filename
 
         record = logging.LogRecord(
             name=name,
@@ -119,7 +113,11 @@ def trace_function(f):
         logger = logging.getLogger(name)
         level = logging.DEBUG
 
-        s = inspect.currentframe().f_back
+        frame = inspect.currentframe()
+        if frame is None:
+            raise Exception("Can't get current frame!")
+
+        s = frame.f_back
 
         to_print = [
             "\t%s:%s %s. Args: args=%s, kwargs=%s"
@@ -144,7 +142,7 @@ def trace_function(f):
             pathname=pathname,
             lineno=lineno,
             msg=msg,
-            args=None,
+            args=(),
             exc_info=None,
         )
 
@@ -157,7 +155,12 @@ def trace_function(f):
 
 
 def get_previous_frames():
-    s = inspect.currentframe().f_back.f_back
+
+    frame = inspect.currentframe()
+    if frame is None:
+        raise Exception("Can't get current frame!")
+
+    s = frame.f_back.f_back
     to_return = []
     while s:
         if s.f_globals["__name__"].startswith("synapse"):
@@ -174,7 +177,10 @@ def get_previous_frames():
 
 
 def get_previous_frame(ignore=[]):
-    s = inspect.currentframe().f_back.f_back
+    frame = inspect.currentframe()
+    if frame is None:
+        raise Exception("Can't get current frame!")
+    s = frame.f_back.f_back
 
     while s:
         if s.f_globals["__name__"].startswith("synapse"):