Diffstat (limited to 'synapse/logging')
-rw-r--r--  synapse/logging/_structured.py   | 374
-rw-r--r--  synapse/logging/_terse_json.py   | 280
-rw-r--r--  synapse/logging/context.py       |  27
-rw-r--r--  synapse/logging/opentracing.py   | 419
4 files changed, 919 insertions(+), 181 deletions(-)
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
new file mode 100644
index 0000000000..3220e985a9
--- /dev/null
+++ b/synapse/logging/_structured.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os.path
+import sys
+import typing
+import warnings
+from typing import List
+
+import attr
+from constantly import NamedConstant, Names, ValueConstant, Values
+from zope.interface import implementer
+
+from twisted.logger import (
+    FileLogObserver,
+    FilteringLogObserver,
+    ILogObserver,
+    LogBeginner,
+    Logger,
+    LogLevel,
+    LogLevelFilterPredicate,
+    LogPublisher,
+    eventAsText,
+    jsonFileLogObserver,
+)
+
+from synapse.config._base import ConfigError
+from synapse.logging._terse_json import (
+    TerseJSONToConsoleLogObserver,
+    TerseJSONToTCPLogObserver,
+)
+from synapse.logging.context import LoggingContext
+
+
+def stdlib_log_level_to_twisted(level: str) -> LogLevel:
+    """
+    Convert a stdlib log level to Twisted's log level.
+    """
+    lvl = level.lower().replace("warning", "warn")
+    return LogLevel.levelWithName(lvl)
+
+
+@attr.s
+@implementer(ILogObserver)
+class LogContextObserver(object):
+    """
+    An ILogObserver which adds Synapse-specific log context information.
+
+    Attributes:
+        observer (ILogObserver): The target parent observer.
+    """
+
+    observer = attr.ib()
+
+    def __call__(self, event: dict) -> None:
+        """
+        Consume a log event and emit it to the parent observer after filtering
+        and adding log context information.
+
+        Args:
+            event (dict)
+        """
+        # Filter out some useless events that Twisted outputs
+        if "log_text" in event:
+            if event["log_text"].startswith("DNSDatagramProtocol starting on "):
+                return
+
+            if event["log_text"].startswith("(UDP Port "):
+                return
+
+            if event["log_text"].startswith("Timing out client") or event[
+                "log_format"
+            ].startswith("Timing out client"):
+                return
+
+        context = LoggingContext.current_context()
+
+        # Copy the context information to the log event.
+        if context is not None:
+            context.copy_to_twisted_log_entry(event)
+        else:
+            # If there's no logging context, not even the root one, we might be
+            # starting up or it might be from non-Synapse code. Log it as if it
+            # came from the root logger.
+            event["request"] = None
+            event["scope"] = None
+
+        self.observer(event)
+
+
+class PythonStdlibToTwistedLogger(logging.Handler):
+    """
+    Transform a Python stdlib log message into a Twisted one.
+    """
+
+    def __init__(self, observer, *args, **kwargs):
+        """
+        Args:
+            observer (ILogObserver): A Twisted logging observer.
+            *args, **kwargs: Args/kwargs to be passed to logging.Handler.
+        """
+        self.observer = observer
+        super().__init__(*args, **kwargs)
+
+    def emit(self, record: logging.LogRecord) -> None:
+        """
+        Emit a record to Twisted's observer.
+
+        Args:
+            record (logging.LogRecord)
+        """
+
+        self.observer(
+            {
+                "log_time": record.created,
+                "log_text": record.getMessage(),
+                "log_format": "{log_text}",
+                "log_namespace": record.name,
+                "log_level": stdlib_log_level_to_twisted(record.levelname),
+            }
+        )
+
+
+def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
+    """
+    A log observer that formats events like the traditional log formatter and
+    sends them to `outFile`.
+
+    Args:
+        outFile (file object): The file object to write to.
+    """
+
+    def formatEvent(_event: dict) -> str:
+        event = dict(_event)
+        event["log_level"] = event["log_level"].name.upper()
+        event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
+            event.get("log_format", "{log_text}") or "{log_text}"
+        )
+        return eventAsText(event, includeSystem=False) + "\n"
+
+    return FileLogObserver(outFile, formatEvent)
+
+
+class DrainType(Names):
+    CONSOLE = NamedConstant()
+    CONSOLE_JSON = NamedConstant()
+    CONSOLE_JSON_TERSE = NamedConstant()
+    FILE = NamedConstant()
+    FILE_JSON = NamedConstant()
+    NETWORK_JSON_TERSE = NamedConstant()
+
+
+class OutputPipeType(Values):
+    stdout = ValueConstant(sys.__stdout__)
+    stderr = ValueConstant(sys.__stderr__)
+
+
+@attr.s
+class DrainConfiguration(object):
+    name = attr.ib()
+    type = attr.ib()
+    location = attr.ib()
+    options = attr.ib(default=None)
+
+
+@attr.s
+class NetworkJSONTerseOptions(object):
+    maximum_buffer = attr.ib(type=int)
+
+
+DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
+
+
+def parse_drain_configs(
+    drains: dict
+) -> typing.Generator[DrainConfiguration, None, None]:
+    """
+    Parse the drain configurations.
+
+    Args:
+        drains (dict): A dict of drain configurations, keyed by drain name.
+
+    Yields:
+        DrainConfiguration instances.
+
+    Raises:
+        ConfigError: If any of the drain configuration items are invalid.
+    """
+    for name, config in drains.items():
+        if "type" not in config:
+            raise ConfigError("Logging drains require a 'type' key.")
+
+        try:
+            logging_type = DrainType.lookupByName(config["type"].upper())
+        except ValueError:
+            raise ConfigError(
+                "%s is not a known logging drain type." % (config["type"],)
+            )
+
+        if logging_type in [
+            DrainType.CONSOLE,
+            DrainType.CONSOLE_JSON,
+            DrainType.CONSOLE_JSON_TERSE,
+        ]:
+            location = config.get("location")
+            if location is None or location not in ["stdout", "stderr"]:
+                raise ConfigError(
+                    (
+                        "The %s drain needs the 'location' key set to "
+                        "either 'stdout' or 'stderr'."
+                    )
+                    % (logging_type,)
+                )
+
+            pipe = OutputPipeType.lookupByName(location).value
+
+            yield DrainConfiguration(name=name, type=logging_type, location=pipe)
+
+        elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
+            if "location" not in config:
+                raise ConfigError(
+                    "The %s drain needs the 'location' key set." % (logging_type,)
+                )
+
+            location = config.get("location")
+            if os.path.abspath(location) != location:
+                raise ConfigError(
+                    "File paths need to be absolute, '%s' is a relative path"
+                    % (location,)
+                )
+            yield DrainConfiguration(name=name, type=logging_type, location=location)
+
+        elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
+            host = config.get("host")
+            port = config.get("port")
+            maximum_buffer = config.get("maximum_buffer", 1000)
+            yield DrainConfiguration(
+                name=name,
+                type=logging_type,
+                location=(host, port),
+                options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
+            )
+
+        else:
+            raise ConfigError(
+                "The %s drain type is currently not implemented."
+                % (config["type"].upper(),)
+            )
+
+
+def setup_structured_logging(
+    hs,
+    config,
+    log_config: dict,
+    logBeginner: LogBeginner,
+    redirect_stdlib_logging: bool = True,
+) -> LogPublisher:
+    """
+    Set up Twisted's structured logging system.
+
+    Args:
+        hs: The homeserver to use.
+        config (HomeserverConfig): The configuration of the Synapse homeserver.
+        log_config (dict): The log configuration to use.
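+        logBeginner (LogBeginner): The Twisted logBeginner to attach the
+            resulting observer to.
+        redirect_stdlib_logging (bool): Whether to redirect stdlib logging
+            into the Twisted observer chain.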
+    """
+    if config.no_redirect_stdio:
+        raise ConfigError(
+            "no_redirect_stdio cannot be defined using structured logging."
+        )
+
+    logger = Logger()
+
+    if "drains" not in log_config:
+        raise ConfigError("The logging configuration requires a list of drains.")
+
+    observers = []  # type: List[ILogObserver]
+
+    for observer in parse_drain_configs(log_config["drains"]):
+        # Pipe drains
+        if observer.type == DrainType.CONSOLE:
+            logger.debug(
+                "Starting up the {name} console logger drain", name=observer.name
+            )
+            observers.append(SynapseFileLogObserver(observer.location))
+        elif observer.type == DrainType.CONSOLE_JSON:
+            logger.debug(
+                "Starting up the {name} JSON console logger drain", name=observer.name
+            )
+            observers.append(jsonFileLogObserver(observer.location))
+        elif observer.type == DrainType.CONSOLE_JSON_TERSE:
+            logger.debug(
+                "Starting up the {name} terse JSON console logger drain",
+                name=observer.name,
+            )
+            observers.append(
+                TerseJSONToConsoleLogObserver(observer.location, metadata={})
+            )
+
+        # File drains
+        elif observer.type == DrainType.FILE:
+            logger.debug("Starting up the {name} file logger drain", name=observer.name)
+            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+            observers.append(SynapseFileLogObserver(log_file))
+        elif observer.type == DrainType.FILE_JSON:
+            logger.debug(
+                "Starting up the {name} JSON file logger drain", name=observer.name
+            )
+            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+            observers.append(jsonFileLogObserver(log_file))
+
+        elif observer.type == DrainType.NETWORK_JSON_TERSE:
+            metadata = {"server_name": hs.config.server_name}
+            log_observer = TerseJSONToTCPLogObserver(
+                hs=hs,
+                host=observer.location[0],
+                port=observer.location[1],
+                metadata=metadata,
+                maximum_buffer=observer.options.maximum_buffer,
+            )
+            log_observer.start()
+            observers.append(log_observer)
+        else:
+            # We should never get here, but, just in case, throw an error.
+            raise ConfigError("%s drain type cannot be configured" % (observer.type,))
+
+    publisher = LogPublisher(*observers)
+    log_filter = LogLevelFilterPredicate()
+
+    for namespace, namespace_config in log_config.get(
+        "loggers", DEFAULT_LOGGERS
+    ).items():
+        # Set the log level for twisted.logger.Logger namespaces
+        log_filter.setLogLevelForNamespace(
+            namespace,
+            stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
+        )
+
+        # Also set the log levels for the stdlib logger namespaces, to prevent
+        # them from reaching PythonStdlibToTwistedLogger and having to be formatted
+        if "level" in namespace_config:
+            logging.getLogger(namespace).setLevel(namespace_config.get("level"))
+
+    f = FilteringLogObserver(publisher, [log_filter])
+    lco = LogContextObserver(f)
+
+    if redirect_stdlib_logging:
+        stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
+        stdliblogger = logging.getLogger()
+        stdliblogger.addHandler(stuff_into_twisted)
+
+    # Always redirect standard I/O, otherwise other logging outputs might miss
+    # it.
+    logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
+
+    return publisher
+
+
+def reload_structured_logging(*args, log_config=None) -> None:
+    warnings.warn(
+        "Currently the structured logging system can not be reloaded, doing nothing"
+    )
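
As a rough illustration of the configuration shape parse_drain_configs expects, a minimal sketch (the drain names, file path, host and port below are hypothetical; the type values come from DrainType above):

    from synapse.logging._structured import parse_drain_configs

    # Hypothetical "drains" section of a structured logging config.
    drains = {
        "console": {"type": "console", "location": "stdout"},
        "file": {"type": "file_json", "location": "/var/log/synapse/events.json"},
        "remote": {"type": "network_json_terse", "host": "10.1.2.3", "port": 9999},
    }

    # Each entry is parsed into a DrainConfiguration instance.
    for drain in parse_drain_configs(drains):
        print(drain.name, drain.type, drain.location, drain.options)
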
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
new file mode 100644
index 0000000000..0ebbde06f2
--- /dev/null
+++ b/synapse/logging/_terse_json.py
@@ -0,0 +1,280 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Log formatters that output terse JSON.
+"""
+
+import sys
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing import IO
+
+import attr
+from simplejson import dumps
+from zope.interface import implementer
+
+from twisted.application.internet import ClientService
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    TCP4ClientEndpoint,
+    TCP6ClientEndpoint,
+)
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import FileLogObserver, ILogObserver, Logger
+from twisted.python.failure import Failure
+
+
+def flatten_event(event: dict, metadata: dict, include_time: bool = False):
+    """
+    Flatten a Twisted logging event to a dictionary capable of being sent
+    as a log event to a logging aggregation system.
+
+    The format is vastly simplified and is not designed to be a "human readable
+    string" in the sense that traditional logs are. Instead, the structure is
+    optimised for searchability and filtering, with human-understandable log
+    keys.
+
+    Args:
+        event (dict): The Twisted logging event we are flattening.
+        metadata (dict): Additional data to include with each log message. This
+            can be information like the server name. Since the target log
+            consumer does not know who we are other than by host IP, this
+            allows us to forward through static information.
+        include_time (bool): Should we include the `time` key? If False, the
+            event time is stripped from the event.
+    """
+    new_event = {}
+
+    # If it's a failure, make the new event's log_failure be the traceback text.
+    if "log_failure" in event:
+        new_event["log_failure"] = event["log_failure"].getTraceback()
+
+    # If it's a warning, copy over a string representation of the warning.
+    if "warning" in event:
+        new_event["warning"] = str(event["warning"])
+
+    # Stdlib logging events have "log_text" as their human-readable portion,
+    # Twisted ones have "log_format". For now, include the log_format, so that
+    # context only given in the log format (e.g. what is being logged) is
+    # available.
+    if "log_text" in event:
+        new_event["log"] = event["log_text"]
+    else:
+        new_event["log"] = event["log_format"]
+
+    # We want to include the timestamp when forwarding over the network, but
+    # exclude it when we are writing to stdout. This is because the log ingester
+    # (e.g. logstash, fluentd) can add its own timestamp.
+    if include_time:
+        new_event["time"] = round(event["log_time"], 2)
+
+    # Convert the log level to a textual representation.
+    new_event["level"] = event["log_level"].name.upper()
+
+    # Ignore these keys, and do not transfer them over to the new log object.
+    # They are either useless (isError), transferred manually above (log_time,
+    # log_level, etc), or contain Python objects which are not useful for output
+    # (log_logger, log_source).
+    keys_to_delete = [
+        "isError",
+        "log_failure",
+        "log_format",
+        "log_level",
+        "log_logger",
+        "log_source",
+        "log_system",
+        "log_time",
+        "log_text",
+        "observer",
+        "warning",
+    ]
+
+    # If it's from the Twisted legacy logger (twisted.python.log), it adds some
+    # more keys we want to purge.
+    if event.get("log_namespace") == "log_legacy":
+        keys_to_delete.extend(["message", "system", "time"])
+
+    # Rather than modify the dictionary in place, construct a new one with only
+    # the content we want. The original event should be considered 'frozen'.
+    for key in event.keys():
+
+        if key in keys_to_delete:
+            continue
+
+        if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
+            # If it's a plain type, include it as is.
+            new_event[key] = event[key]
+        else:
+            # If it's not one of those basic types, write out a string
+            # representation. This should probably be a warning in development,
+            # so that we are sure we are only outputting useful data.
+            new_event[key] = str(event[key])
+
+    # Add the metadata information to the event (e.g. the server_name).
+    new_event.update(metadata)
+
+    return new_event
+
+
+def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
+    """
+    A log observer that formats events to a flattened JSON representation.
+
+    Args:
+        outFile: The file object to write to.
+        metadata: Metadata to be added to each log object.
+    """
+
+    def formatEvent(_event: dict) -> str:
+        flattened = flatten_event(_event, metadata)
+        return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+
+    return FileLogObserver(outFile, formatEvent)
+
+
+@attr.s
+@implementer(ILogObserver)
+class TerseJSONToTCPLogObserver(object):
+    """
+    An ILogObserver that writes JSON logs to a TCP target.
+
+    Args:
+        hs (HomeServer): The homeserver whose logs are being forwarded.
+        host: The host of the logging target.
+        port: The logging target's port.
+        metadata: Metadata to be added to each log entry.
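+        maximum_buffer: The maximum number of buffered log events before
+            backpressure handling starts shedding them.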
+    """
+
+    hs = attr.ib()
+    host = attr.ib(type=str)
+    port = attr.ib(type=int)
+    metadata = attr.ib(type=dict)
+    maximum_buffer = attr.ib(type=int)
+    _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+    _writer = attr.ib(default=None)
+    _logger = attr.ib(default=attr.Factory(Logger))
+
+    def start(self) -> None:
+
+        # Connect without DNS lookups if it's a direct IP.
+        try:
+            ip = ip_address(self.host)
+            if isinstance(ip, IPv4Address):
+                endpoint = TCP4ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+            elif isinstance(ip, IPv6Address):
+                endpoint = TCP6ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+        except ValueError:
+            endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+        factory = Factory.forProtocol(Protocol)
+        self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+        self._service.startService()
+
+    def _write_loop(self) -> None:
+        """
+        Implement the write loop.
+        """
+        if self._writer:
+            return
+
+        self._writer = self._service.whenConnected()
+
+        @self._writer.addBoth
+        def writer(r):
+            if isinstance(r, Failure):
+                r.printTraceback(file=sys.__stderr__)
+                self._writer = None
+                self.hs.get_reactor().callLater(1, self._write_loop)
+                return
+
+            try:
+                for event in self._buffer:
+                    r.transport.write(
+                        dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
+                            "utf8"
+                        )
+                    )
+                    r.transport.write(b"\n")
+                self._buffer.clear()
+            except Exception as e:
+                sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
+
+            self._writer = False
+            self.hs.get_reactor().callLater(1, self._write_loop)
+
+    def _handle_pressure(self) -> None:
+        """
+        Handle backpressure by shedding events.
+
+        Events will be shed, in this order, until the buffer is below the maximum:
+            - Shed DEBUG events
+            - Shed INFO events
+            - Shed the middle 50% of the events.
+        """
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out DEBUGs
+        self._buffer = deque(
+            filter(lambda event: event["level"] != "DEBUG", self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out INFOs
+        self._buffer = deque(
+            filter(lambda event: event["level"] != "INFO", self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Cut the middle entries out
+        buffer_split = floor(self.maximum_buffer / 2)
+
+        old_buffer = self._buffer
+        self._buffer = deque()
+
+        for i in range(buffer_split):
+            self._buffer.append(old_buffer.popleft())
+
+        end_buffer = []
+        for i in range(buffer_split):
+            end_buffer.append(old_buffer.pop())
+
+        self._buffer.extend(reversed(end_buffer))
+
+    def __call__(self, event: dict) -> None:
+        flattened = flatten_event(event, self.metadata, include_time=True)
+        self._buffer.append(flattened)
+
+        # Handle backpressure, if it exists.
+        try:
+            self._handle_pressure()
+        except Exception:
+            # If handling backpressure fails, clear the buffer and log the
+            # exception.
+            self._buffer.clear()
+            self._logger.failure("Failed clearing backpressure")
+
+        # Try and write immediately.
+        self._write_loop()
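
To sketch what flatten_event produces, a hand-built event dict imitating what twisted.logger emits (the namespace, timestamp and server name are made up):

    from twisted.logger import LogLevel

    from synapse.logging._terse_json import flatten_event

    # A hand-built event imitating what twisted.logger would emit.
    event = {
        "log_format": "Connected to {host}",
        "log_level": LogLevel.info,
        "log_time": 1566171600.123456,
        "log_namespace": "synapse.demo",
        "host": "example.com",
    }

    flattened = flatten_event(event, {"server_name": "hs.example.com"}, include_time=True)
    # flattened is now roughly:
    # {"log": "Connected to {host}", "level": "INFO", "time": 1566171600.12,
    #  "log_namespace": "synapse.demo", "host": "example.com",
    #  "server_name": "hs.example.com"}
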
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index b456c31f70..370000e377 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ See doc/log_contexts.rst for details on how this works.
 import logging
 import threading
 import types
+from typing import Any, List
 
 from twisted.internet import defer, threads
 
@@ -41,13 +43,17 @@ try:
     # exception.
     resource.getrusage(RUSAGE_THREAD)
 
+    is_thread_resource_usage_supported = True
+
     def get_thread_resource_usage():
         return resource.getrusage(RUSAGE_THREAD)
 
 
 except Exception:
     # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
-    # won't track resource usage by returning None.
+    # won't track resource usage.
+    is_thread_resource_usage_supported = False
+
     def get_thread_resource_usage():
         return None
 
@@ -194,7 +200,7 @@ class LoggingContext(object):
     class Sentinel(object):
         """Sentinel to represent the root context"""
 
-        __slots__ = []
+        __slots__ = []  # type: List[Any]
 
         def __str__(self):
             return "sentinel"
@@ -202,6 +208,10 @@ class LoggingContext(object):
         def copy_to(self, record):
             pass
 
+        def copy_to_twisted_log_entry(self, record):
+            record["request"] = None
+            record["scope"] = None
+
         def start(self):
             pass
 
@@ -330,6 +340,13 @@ class LoggingContext(object):
         # we also track the current scope:
         record.scope = self.scope
 
+    def copy_to_twisted_log_entry(self, record):
+        """
+        Copy logging fields from this context to a Twisted log record.
+        """
+        record["request"] = self.request
+        record["scope"] = self.scope
+
     def start(self):
         if get_thread_id() != self.main_thread:
             logger.warning("Started logcontext %s on different thread", self)
@@ -347,7 +364,11 @@ class LoggingContext(object):
 
         # When we stop, let's record the cpu used since we started
         if not self.usage_start:
-            logger.warning("Called stop on logcontext %s without calling start", self)
+            # Log a warning on platforms that support thread usage tracking
+            if is_thread_resource_usage_supported:
+                logger.warning(
+                    "Called stop on logcontext %s without calling start", self
+                )
             return
 
         utime_delta, stime_delta = self._get_cputime()
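
A minimal sketch of the new copy_to_twisted_log_entry in use (the context name and request id below are made up):

    from synapse.logging.context import LoggingContext

    with LoggingContext("demo") as context:
        context.request = "GET-42"  # hypothetical request id

        event = {"log_format": "hello"}
        context.copy_to_twisted_log_entry(event)
        # event now carries the logcontext fields:
        # {"log_format": "hello", "request": "GET-42", "scope": None}
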
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index d2c209c471..308a27213b 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -43,6 +43,9 @@ OpenTracing to be easily disabled in Synapse and thereby have OpenTracing as
 an optional dependency. This does however limit the number of modifiable spans
 at any point in the code to one. From here out references to `opentracing`
 in the code snippets refer to Synapse's module.
+Most methods provided in the module correspond directly to those provided
+by opentracing. Refer to the opentracing docs for more in-depth documentation on
+some of the args and methods.
 
 Tracing
 -------
@@ -68,52 +71,62 @@ set a tag on the current active span.
 Tracing functions
 -----------------
 
-Functions can be easily traced using decorators. There is a decorator for
-'normal' function and for functions which are actually deferreds. The name of
+Functions can be easily traced using decorators. The name of
 the function becomes the operation name for the span.
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace, trace_deferred
+   from synapse.logging.opentracing import trace
 
-   # Start a span using 'normal_function' as the operation name
+   # Start a span using 'interesting_function' as the operation name
    @trace
-   def normal_function(*args, **kwargs):
+   def interesting_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
 
-   # Start a span using 'deferred_function' as the operation name
-   @trace_deferred
-   @defer.inlineCallbacks
-   def deferred_function(*args, **kwargs):
-       # We start
-       yield we_wait
-       # we finish
-       return something_usual_and_useful
 
-Operation names can be explicitly set for functions by using
-``trace_using_operation_name`` and
-``trace_deferred_using_operation_name``
+Operation names can be explicitly set for a function by passing the
+operation name to ``trace``
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import (
-       trace_using_operation_name,
-       trace_deferred_using_operation_name
-   )
+   from synapse.logging.opentracing import trace
 
-   @trace_using_operation_name("A *much* better operation name")
-   def normal_function(*args, **kwargs):
+   @trace(opname="a_better_operation_name")
+   def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
 
-   @trace_deferred_using_operation_name("Another exciting operation name!")
-   @defer.inlineCallbacks
-   def deferred_function(*args, **kwargs):
-       # We start
-       yield we_wait
-       # we finish
-       return something_usual_and_useful
+Setting Tags
+------------
+
+To set a tag on the active span do
+
+.. code-block:: python
+
+   from synapse.logging.opentracing import set_tag
+
+   set_tag(tag_name, tag_value)
+
+There's a convenient decorator to tag all the args of a method. It uses
+inspection to tag positional args with their formal parameter names, prefixed
+with 'ARG_'. Kwarg names are used as tag names without the prefix.
+
+.. code-block:: python
+
+   from synapse.logging.opentracing import tag_args
+
+   @tag_args
+   def set_fates(clotho, lachesis, atropos, father="Zeus", mother="Themis"):
+       pass
+
+   set_fates("the story", "the end", "the act")
+   # This will have the following tags
+   #  - ARG_clotho: "the story"
+   #  - ARG_lachesis: "the end"
+   #  - ARG_atropos: "the act"
+   #  - father: "Zues"
+   #  - mother: "Themis"
 
 Contexts and carriers
 ---------------------
 uncharted waters will require the enforcement of the whitelist.
 ``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
 in a destination and compares it to the whitelist.
 
+Most injection methods take a 'destination' arg. The context will only be injected
+if the destination matches the whitelist or the destination is None.
+
 =======
 Gotchas
 =======
@@ -161,16 +177,54 @@ from twisted.internet import defer
 
 from synapse.config import ConfigError
 
+# Helper class
+
+
+class _DummyTagNames(object):
+    """wrapper of opentracings tags. We need to have them if we
+    want to reference them without opentracing around. Clearly they
+    should never actually show up in a trace. `set_tags` overwrites
+    these with the correct ones."""
+
+    INVALID_TAG = "invalid-tag"
+    COMPONENT = INVALID_TAG
+    DATABASE_INSTANCE = INVALID_TAG
+    DATABASE_STATEMENT = INVALID_TAG
+    DATABASE_TYPE = INVALID_TAG
+    DATABASE_USER = INVALID_TAG
+    ERROR = INVALID_TAG
+    HTTP_METHOD = INVALID_TAG
+    HTTP_STATUS_CODE = INVALID_TAG
+    HTTP_URL = INVALID_TAG
+    MESSAGE_BUS_DESTINATION = INVALID_TAG
+    PEER_ADDRESS = INVALID_TAG
+    PEER_HOSTNAME = INVALID_TAG
+    PEER_HOST_IPV4 = INVALID_TAG
+    PEER_HOST_IPV6 = INVALID_TAG
+    PEER_PORT = INVALID_TAG
+    PEER_SERVICE = INVALID_TAG
+    SAMPLING_PRIORITY = INVALID_TAG
+    SERVICE = INVALID_TAG
+    SPAN_KIND = INVALID_TAG
+    SPAN_KIND_CONSUMER = INVALID_TAG
+    SPAN_KIND_PRODUCER = INVALID_TAG
+    SPAN_KIND_RPC_CLIENT = INVALID_TAG
+    SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
 try:
     import opentracing
+
+    tags = opentracing.tags
 except ImportError:
     opentracing = None
+    tags = _DummyTagNames
 try:
     from jaeger_client import Config as JaegerConfig
     from synapse.logging.scopecontextmanager import LogContextScopeManager
 except ImportError:
-    JaegerConfig = None
-    LogContextScopeManager = None
+    JaegerConfig = None  # type: ignore
+    LogContextScopeManager = None  # type: ignore
 
 
 logger = logging.getLogger(__name__)
@@ -185,8 +239,7 @@ _homeserver_whitelist = None
 
 
 def only_if_tracing(func):
-    """Executes the function only if we're tracing. Otherwise return.
-    Assumes the function wrapped may return None"""
+    """Executes the function only if we're tracing. Otherwise returns None."""
 
     @wraps(func)
     def _only_if_tracing_inner(*args, **kwargs):
@@ -198,6 +251,41 @@ def only_if_tracing(func):
     return _only_if_tracing_inner
 
 
+def ensure_active_span(message, ret=None):
+    """Executes the operation only if opentracing is enabled and there is an active span.
+    If there is no active span it logs message at the error level.
+
+    Args:
+        message (str): Message which fills in "There was no active span when trying to %s"
+            in the error log if there is no active span and opentracing is enabled.
+        ret (object): return value if opentracing is None or there is no active span.
+
+    Returns (object): The result of func, or ret if opentracing is disabled or there
+        was no active span.
+    """
+
+    def ensure_active_span_inner_1(func):
+        @wraps(func)
+        def ensure_active_span_inner_2(*args, **kwargs):
+            if not opentracing:
+                return ret
+
+            if not opentracing.tracer.active_span:
+                logger.error(
+                    "There was no active span when trying to %s."
+                    " Did you forget to start one or did a context slip?",
+                    message,
+                )
+
+                return ret
+
+            return func(*args, **kwargs)
+
+        return ensure_active_span_inner_2
+
+    return ensure_active_span_inner_1
+
+
 @contextlib.contextmanager
 def _noop_context_manager(*args, **kwargs):
     """Does exactly what it says on the tin"""
@@ -239,10 +327,6 @@ def init_tracer(config):
         scope_manager=LogContextScopeManager(config),
     ).initialize_tracer()
 
-    # Set up tags to be opentracing's tags
-    global tags
-    tags = opentracing.tags
-
 
 # Whitelisting
 
@@ -269,7 +353,7 @@ def whitelisted_homeserver(destination):
     Args:
         destination (str)
         """
-    _homeserver_whitelist
+
     if _homeserver_whitelist:
         return _homeserver_whitelist.match(destination)
     return False
@@ -299,30 +383,28 @@ def start_active_span(
     if opentracing is None:
         return _noop_context_manager()
 
-    else:
-        # We need to enter the scope here for the logcontext to become active
-        return opentracing.tracer.start_active_span(
-            operation_name,
-            child_of=child_of,
-            references=references,
-            tags=tags,
-            start_time=start_time,
-            ignore_active_span=ignore_active_span,
-            finish_on_close=finish_on_close,
-        )
+    return opentracing.tracer.start_active_span(
+        operation_name,
+        child_of=child_of,
+        references=references,
+        tags=tags,
+        start_time=start_time,
+        ignore_active_span=ignore_active_span,
+        finish_on_close=finish_on_close,
+    )
 
 
 def start_active_span_follows_from(operation_name, contexts):
     if opentracing is None:
         return _noop_context_manager()
-    else:
-        references = [opentracing.follows_from(context) for context in contexts]
-        scope = start_active_span(operation_name, references=references)
-        return scope
+
+    references = [opentracing.follows_from(context) for context in contexts]
+    scope = start_active_span(operation_name, references=references)
+    return scope
 
 
-def start_active_span_from_context(
-    headers,
+def start_active_span_from_request(
+    request,
     operation_name,
     references=None,
     tags=None,
@@ -331,9 +413,9 @@ def start_active_span_from_context(
     finish_on_close=True,
 ):
     """
-    Extracts a span context from Twisted Headers.
+    Extracts a span context from a Twisted Request.
     args:
-        headers (twisted.web.http_headers.Headers)
+        request (twisted.web.http.Request)
 
         For the other args see opentracing.tracer
 
@@ -347,7 +429,9 @@ def start_active_span_from_context(
     if opentracing is None:
         return _noop_context_manager()
 
-    header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()}
+    header_dict = {
+        k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+    }
     context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
 
     return opentracing.tracer.start_active_span(
@@ -413,19 +497,19 @@ def start_active_span_from_edu(
 # Opentracing setters for tags, logs, etc
 
 
-@only_if_tracing
+@ensure_active_span("set a tag")
 def set_tag(key, value):
     """Sets a tag on the active span"""
     opentracing.tracer.active_span.set_tag(key, value)
 
 
-@only_if_tracing
+@ensure_active_span("log")
 def log_kv(key_values, timestamp=None):
     """Log to the active span"""
     opentracing.tracer.active_span.log_kv(key_values, timestamp)
 
 
-@only_if_tracing
+@ensure_active_span("set the traces operation name")
 def set_operation_name(operation_name):
     """Sets the operation name of the active span"""
     opentracing.tracer.active_span.set_operation_name(operation_name)
@@ -434,13 +518,18 @@ def set_operation_name(operation_name):
 # Injection and extraction
 
 
-@only_if_tracing
-def inject_active_span_twisted_headers(headers, destination):
+@ensure_active_span("inject the span into a header")
+def inject_active_span_twisted_headers(headers, destination, check_destination=True):
     """
     Injects a span context into twisted headers in-place
 
     Args:
         headers (twisted.web.http_headers.Headers)
+        destination (str): address of the entity receiving the span context. If
+            check_destination is true, the context will only be injected if the
+            destination matches the opentracing whitelist.
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
         span (opentracing.Span)
 
     Returns:
@@ -454,7 +543,7 @@ def inject_active_span_twisted_headers(headers, destination):
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
 
-    if not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     span = opentracing.tracer.active_span
@@ -465,14 +554,19 @@ def inject_active_span_twisted_headers(headers, destination):
         headers.addRawHeaders(key, value)
 
 
-@only_if_tracing
-def inject_active_span_byte_dict(headers, destination):
+@ensure_active_span("inject the span into a byte dict")
+def inject_active_span_byte_dict(headers, destination, check_destination=True):
     """
     Injects a span context into a dict where the headers are encoded as byte
     strings
 
     Args:
         headers (dict)
+        destination (str): address of the entity receiving the span context. If
+            check_destination is true, the context will only be injected if the
+            destination matches the opentracing whitelist.
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
         span (opentracing.Span)
 
     Returns:
@@ -485,7 +579,7 @@ def inject_active_span_byte_dict(headers, destination):
         here:
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
-    if not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     span = opentracing.tracer.active_span
@@ -497,16 +591,18 @@ def inject_active_span_byte_dict(headers, destination):
         headers[key.encode()] = [value.encode()]
 
 
-@only_if_tracing
-def inject_active_span_text_map(carrier, destination=None):
+@ensure_active_span("inject the span into a text map")
+def inject_active_span_text_map(carrier, destination, check_destination=True):
     """
     Injects a span context into a dict
 
     Args:
         carrier (dict)
-        destination (str): the name of the remote server. The span context
-        will only be injected if the destination matches the homeserver_whitelist
-        or destination is None.
+        destination (str): address of the entity receiving the span context. If
+            check_destination is true, the context will only be injected if the
+            destination matches the opentracing whitelist.
+        check_destination (bool): If false, destination will be ignored and the context
+            will always be injected.
 
     Returns:
         In-place modification of carrier
@@ -519,7 +615,7 @@ def inject_active_span_text_map(carrier, destination=None):
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
 
-    if destination and not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     opentracing.tracer.inject(
@@ -527,6 +623,31 @@ def inject_active_span_text_map(carrier, destination=None):
     )
 
 
+@ensure_active_span("get the active span context as a dict", ret={})
+def get_active_span_text_map(destination=None):
+    """
+    Gets a span context as a dict. This can be used instead of manually
+    injecting a span into an empty carrier.
+
+    Args:
+        destination (str): the name of the remote server.
+
+    Returns:
+        dict: the active span's context if opentracing is enabled, otherwise empty.
+    """
+
+    if destination and not whitelisted_homeserver(destination):
+        return {}
+
+    carrier = {}
+    opentracing.tracer.inject(
+        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+    )
+
+    return carrier
+
+
+@ensure_active_span("get the span context as a string.", ret={})
 def active_span_context_as_string():
     """
     Returns:
@@ -566,70 +687,30 @@ def extract_text_map(carrier):
 # Tracing decorators
 
 
-def trace(func):
+def trace(func=None, opname=None):
     """
     Decorator to trace a function.
-    Sets the operation name to that of the function's.
+    Sets the operation name to the function's name, or to the given
+    opname. See the module's docstring for usage
+    examples.
     """
-    if opentracing is None:
-        return func
-
-    @wraps(func)
-    def _trace_inner(self, *args, **kwargs):
-        if opentracing is None:
-            return func(self, *args, **kwargs)
 
-        scope = start_active_span(func.__name__)
-        scope.__enter__()
-
-        try:
-            result = func(self, *args, **kwargs)
-            if isinstance(result, defer.Deferred):
-
-                def call_back(result):
-                    scope.__exit__(None, None, None)
-                    return result
-
-                def err_back(result):
-                    scope.span.set_tag(tags.ERROR, True)
-                    scope.__exit__(None, None, None)
-                    return result
-
-                result.addCallbacks(call_back, err_back)
-
-            else:
-                scope.__exit__(None, None, None)
-
-            return result
-
-        except Exception as e:
-            scope.__exit__(type(e), None, e.__traceback__)
-            raise
-
-    return _trace_inner
-
-
-def trace_using_operation_name(operation_name):
-    """Decorator to trace a function. Explicitely sets the operation_name."""
-
-    def trace(func):
-        """
-        Decorator to trace a function.
-        Sets the operation name to that of the function's.
-        """
+    def decorator(func):
         if opentracing is None:
             return func
 
+        _opname = opname if opname else func.__name__
+
         @wraps(func)
-        def _trace_inner(self, *args, **kwargs):
+        def _trace_inner(*args, **kwargs):
             if opentracing is None:
-                return func(self, *args, **kwargs)
+                return func(*args, **kwargs)
 
-            scope = start_active_span(operation_name)
+            scope = start_active_span(_opname)
             scope.__enter__()
 
             try:
-                result = func(self, *args, **kwargs)
+                result = func(*args, **kwargs)
                 if isinstance(result, defer.Deferred):
 
                     def call_back(result):
@@ -642,6 +723,7 @@ def trace_using_operation_name(operation_name):
                         return result
 
                     result.addCallbacks(call_back, err_back)
+
                 else:
                     scope.__exit__(None, None, None)
 
@@ -653,7 +735,10 @@ def trace_using_operation_name(operation_name):
 
         return _trace_inner
 
-    return trace
+    if func:
+        return decorator(func)
+    else:
+        return decorator
 
 
 def tag_args(func):
@@ -665,76 +750,54 @@ def tag_args(func):
         return func
 
     @wraps(func)
-    def _tag_args_inner(self, *args, **kwargs):
+    def _tag_args_inner(*args, **kwargs):
         argspec = inspect.getargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
         set_tag("kwargs", kwargs)
-        return func(self, *args, **kwargs)
+        return func(*args, **kwargs)
 
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, func):
+def trace_servlet(servlet_name, extract_context=False):
     """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information"""
-    if not opentracing:
-        return func
+    tags such as the servlet_name and request information
 
-    @wraps(func)
-    @defer.inlineCallbacks
-    def _trace_servlet_inner(request, *args, **kwargs):
-        with start_active_span(
-            "incoming-client-request",
-            tags={
+    Args:
+        servlet_name (str): The name to be used for the span's operation_name
+        extract_context (bool): Whether to attempt to extract the opentracing
+            context from the request the servlet is handling.
+
+    """
+
+    def _trace_servlet_inner_1(func):
+        if not opentracing:
+            return func
+
+        @wraps(func)
+        @defer.inlineCallbacks
+        def _trace_servlet_inner(request, *args, **kwargs):
+            request_tags = {
                 "request_id": request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                 tags.HTTP_METHOD: request.get_method(),
                 tags.HTTP_URL: request.get_redacted_uri(),
                 tags.PEER_HOST_IPV6: request.getClientIP(),
-                "servlet_name": servlet_name,
-            },
-        ):
-            result = yield defer.maybeDeferred(func, request, *args, **kwargs)
-            return result
-
-    return _trace_servlet_inner
-
-
-# Helper class
+            }
 
+            if extract_context:
+                scope = start_active_span_from_request(
+                    request, servlet_name, tags=request_tags
+                )
+            else:
+                scope = start_active_span(servlet_name, tags=request_tags)
 
-class _DummyTagNames(object):
-    """wrapper of opentracings tags. We need to have them if we
-    want to reference them without opentracing around. Clearly they
-    should never actually show up in a trace. `set_tags` overwrites
-    these with the correct ones."""
-
-    INVALID_TAG = "invalid-tag"
-    COMPONENT = INVALID_TAG
-    DATABASE_INSTANCE = INVALID_TAG
-    DATABASE_STATEMENT = INVALID_TAG
-    DATABASE_TYPE = INVALID_TAG
-    DATABASE_USER = INVALID_TAG
-    ERROR = INVALID_TAG
-    HTTP_METHOD = INVALID_TAG
-    HTTP_STATUS_CODE = INVALID_TAG
-    HTTP_URL = INVALID_TAG
-    MESSAGE_BUS_DESTINATION = INVALID_TAG
-    PEER_ADDRESS = INVALID_TAG
-    PEER_HOSTNAME = INVALID_TAG
-    PEER_HOST_IPV4 = INVALID_TAG
-    PEER_HOST_IPV6 = INVALID_TAG
-    PEER_PORT = INVALID_TAG
-    PEER_SERVICE = INVALID_TAG
-    SAMPLING_PRIORITY = INVALID_TAG
-    SERVICE = INVALID_TAG
-    SPAN_KIND = INVALID_TAG
-    SPAN_KIND_CONSUMER = INVALID_TAG
-    SPAN_KIND_PRODUCER = INVALID_TAG
-    SPAN_KIND_RPC_CLIENT = INVALID_TAG
-    SPAN_KIND_RPC_SERVER = INVALID_TAG
+            with scope:
+                result = yield defer.maybeDeferred(func, request, *args, **kwargs)
+                return result
 
+        return _trace_servlet_inner
 
-tags = _DummyTagNames
+    return _trace_servlet_inner_1
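
To show how the reworked decorators are applied after this change, a brief sketch (the decorated functions below are hypothetical):

    from twisted.internet import defer

    from synapse.logging.opentracing import trace, trace_servlet

    @trace  # the span is named after the function
    def compute_something():
        return 42

    @trace(opname="a_better_operation_name")  # explicit operation name
    def poorly_named_helper():
        return compute_something()

    # trace_servlet now takes the servlet name and returns a decorator,
    # so it wraps a request handler like this:
    @trace_servlet("DemoServlet", extract_context=True)
    def handle_request(request):
        return defer.succeed(b"ok")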