summary refs log tree commit diff
path: root/synapse/logging
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/logging')
-rw-r--r--synapse/logging/_structured.py374
-rw-r--r--synapse/logging/_terse_json.py278
-rw-r--r--synapse/logging/context.py14
-rw-r--r--synapse/logging/opentracing.py237
4 files changed, 810 insertions, 93 deletions
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
new file mode 100644
index 0000000000..0367d6dfc4
--- /dev/null
+++ b/synapse/logging/_structured.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os.path
+import sys
+import typing
+import warnings
+
+import attr
+from constantly import NamedConstant, Names, ValueConstant, Values
+from zope.interface import implementer
+
+from twisted.logger import (
+    FileLogObserver,
+    FilteringLogObserver,
+    ILogObserver,
+    LogBeginner,
+    Logger,
+    LogLevel,
+    LogLevelFilterPredicate,
+    LogPublisher,
+    eventAsText,
+    globalLogBeginner,
+    jsonFileLogObserver,
+)
+
+from synapse.config._base import ConfigError
+from synapse.logging._terse_json import (
+    TerseJSONToConsoleLogObserver,
+    TerseJSONToTCPLogObserver,
+)
+from synapse.logging.context import LoggingContext
+
+
+def stdlib_log_level_to_twisted(level: str) -> LogLevel:
+    """
+    Convert a stdlib log level to Twisted's log level.
+    """
+    lvl = level.lower().replace("warning", "warn")
+    return LogLevel.levelWithName(lvl)
+
+
+@attr.s
+@implementer(ILogObserver)
+class LogContextObserver(object):
+    """
+    An ILogObserver which adds Synapse-specific log context information.
+
+    Attributes:
+        observer (ILogObserver): The target parent observer.
+    """
+
+    observer = attr.ib()
+
+    def __call__(self, event: dict) -> None:
+        """
+        Consume a log event and emit it to the parent observer after filtering
+        and adding log context information.
+
+        Args:
+            event (dict)
+        """
+        # Filter out some useless events that Twisted outputs
+        if "log_text" in event:
+            if event["log_text"].startswith("DNSDatagramProtocol starting on "):
+                return
+
+            if event["log_text"].startswith("(UDP Port "):
+                return
+
+            if event["log_text"].startswith("Timing out client") or event[
+                "log_format"
+            ].startswith("Timing out client"):
+                return
+
+        context = LoggingContext.current_context()
+
+        # Copy the context information to the log event.
+        if context is not None:
+            context.copy_to_twisted_log_entry(event)
+        else:
+            # If there's no logging context, not even the root one, we might be
+            # starting up or it might be from non-Synapse code. Log it as if it
+            # came from the root logger.
+            event["request"] = None
+            event["scope"] = None
+
+        self.observer(event)
+
+
+class PythonStdlibToTwistedLogger(logging.Handler):
+    """
+    Transform a Python stdlib log message into a Twisted one.
+    """
+
+    def __init__(self, observer, *args, **kwargs):
+        """
+        Args:
+            observer (ILogObserver): A Twisted logging observer.
+            *args, **kwargs: Args/kwargs to be passed to logging.Handler.
+        """
+        self.observer = observer
+        super().__init__(*args, **kwargs)
+
+    def emit(self, record: logging.LogRecord) -> None:
+        """
+        Emit a record to Twisted's observer.
+
+        Args:
+            record (logging.LogRecord)
+        """
+
+        self.observer(
+            {
+                "log_time": record.created,
+                "log_text": record.getMessage(),
+                "log_format": "{log_text}",
+                "log_namespace": record.name,
+                "log_level": stdlib_log_level_to_twisted(record.levelname),
+            }
+        )
+
+
+def SynapseFileLogObserver(outFile: typing.io.TextIO) -> FileLogObserver:
+    """
+    A log observer that formats events like the traditional log formatter and
+    sends them to `outFile`.
+
+    Args:
+        outFile (file object): The file object to write to.
+    """
+
+    def formatEvent(_event: dict) -> str:
+        event = dict(_event)
+        event["log_level"] = event["log_level"].name.upper()
+        event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
+            event.get("log_format", "{log_text}") or "{log_text}"
+        )
+        return eventAsText(event, includeSystem=False) + "\n"
+
+    return FileLogObserver(outFile, formatEvent)
+
+
+class DrainType(Names):
+    CONSOLE = NamedConstant()
+    CONSOLE_JSON = NamedConstant()
+    CONSOLE_JSON_TERSE = NamedConstant()
+    FILE = NamedConstant()
+    FILE_JSON = NamedConstant()
+    NETWORK_JSON_TERSE = NamedConstant()
+
+
+class OutputPipeType(Values):
+    stdout = ValueConstant(sys.__stdout__)
+    stderr = ValueConstant(sys.__stderr__)
+
+
+@attr.s
+class DrainConfiguration(object):
+    name = attr.ib()
+    type = attr.ib()
+    location = attr.ib()
+    options = attr.ib(default=None)
+
+
+@attr.s
+class NetworkJSONTerseOptions(object):
+    maximum_buffer = attr.ib(type=int)
+
+
+DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
+
+
+def parse_drain_configs(
+    drains: dict
+) -> typing.Generator[DrainConfiguration, None, None]:
+    """
+    Parse the drain configurations.
+
+    Args:
+        drains (dict): A list of drain configurations.
+
+    Yields:
+        DrainConfiguration instances.
+
+    Raises:
+        ConfigError: If any of the drain configuration items are invalid.
+    """
+    for name, config in drains.items():
+        if "type" not in config:
+            raise ConfigError("Logging drains require a 'type' key.")
+
+        try:
+            logging_type = DrainType.lookupByName(config["type"].upper())
+        except ValueError:
+            raise ConfigError(
+                "%s is not a known logging drain type." % (config["type"],)
+            )
+
+        if logging_type in [
+            DrainType.CONSOLE,
+            DrainType.CONSOLE_JSON,
+            DrainType.CONSOLE_JSON_TERSE,
+        ]:
+            location = config.get("location")
+            if location is None or location not in ["stdout", "stderr"]:
+                raise ConfigError(
+                    (
+                        "The %s drain needs the 'location' key set to "
+                        "either 'stdout' or 'stderr'."
+                    )
+                    % (logging_type,)
+                )
+
+            pipe = OutputPipeType.lookupByName(location).value
+
+            yield DrainConfiguration(name=name, type=logging_type, location=pipe)
+
+        elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
+            if "location" not in config:
+                raise ConfigError(
+                    "The %s drain needs the 'location' key set." % (logging_type,)
+                )
+
+            location = config.get("location")
+            if os.path.abspath(location) != location:
+                raise ConfigError(
+                    "File paths need to be absolute, '%s' is a relative path"
+                    % (location,)
+                )
+            yield DrainConfiguration(name=name, type=logging_type, location=location)
+
+        elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
+            host = config.get("host")
+            port = config.get("port")
+            maximum_buffer = config.get("maximum_buffer", 1000)
+            yield DrainConfiguration(
+                name=name,
+                type=logging_type,
+                location=(host, port),
+                options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
+            )
+
+        else:
+            raise ConfigError(
+                "The %s drain type is currently not implemented."
+                % (config["type"].upper(),)
+            )
+
+
+def setup_structured_logging(
+    hs,
+    config,
+    log_config: dict,
+    logBeginner: LogBeginner = globalLogBeginner,
+    redirect_stdlib_logging: bool = True,
+) -> LogPublisher:
+    """
+    Set up Twisted's structured logging system.
+
+    Args:
+        hs: The homeserver to use.
+        config (HomeserverConfig): The configuration of the Synapse homeserver.
+        log_config (dict): The log configuration to use.
+    """
+    if config.no_redirect_stdio:
+        raise ConfigError(
+            "no_redirect_stdio cannot be defined using structured logging."
+        )
+
+    logger = Logger()
+
+    if "drains" not in log_config:
+        raise ConfigError("The logging configuration requires a list of drains.")
+
+    observers = []
+
+    for observer in parse_drain_configs(log_config["drains"]):
+        # Pipe drains
+        if observer.type == DrainType.CONSOLE:
+            logger.debug(
+                "Starting up the {name} console logger drain", name=observer.name
+            )
+            observers.append(SynapseFileLogObserver(observer.location))
+        elif observer.type == DrainType.CONSOLE_JSON:
+            logger.debug(
+                "Starting up the {name} JSON console logger drain", name=observer.name
+            )
+            observers.append(jsonFileLogObserver(observer.location))
+        elif observer.type == DrainType.CONSOLE_JSON_TERSE:
+            logger.debug(
+                "Starting up the {name} terse JSON console logger drain",
+                name=observer.name,
+            )
+            observers.append(
+                TerseJSONToConsoleLogObserver(observer.location, metadata={})
+            )
+
+        # File drains
+        elif observer.type == DrainType.FILE:
+            logger.debug("Starting up the {name} file logger drain", name=observer.name)
+            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+            observers.append(SynapseFileLogObserver(log_file))
+        elif observer.type == DrainType.FILE_JSON:
+            logger.debug(
+                "Starting up the {name} JSON file logger drain", name=observer.name
+            )
+            log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+            observers.append(jsonFileLogObserver(log_file))
+
+        elif observer.type == DrainType.NETWORK_JSON_TERSE:
+            metadata = {"server_name": hs.config.server_name}
+            log_observer = TerseJSONToTCPLogObserver(
+                hs=hs,
+                host=observer.location[0],
+                port=observer.location[1],
+                metadata=metadata,
+                maximum_buffer=observer.options.maximum_buffer,
+            )
+            log_observer.start()
+            observers.append(log_observer)
+        else:
+            # We should never get here, but, just in case, throw an error.
+            raise ConfigError("%s drain type cannot be configured" % (observer.type,))
+
+    publisher = LogPublisher(*observers)
+    log_filter = LogLevelFilterPredicate()
+
+    for namespace, namespace_config in log_config.get(
+        "loggers", DEFAULT_LOGGERS
+    ).items():
+        # Set the log level for twisted.logger.Logger namespaces
+        log_filter.setLogLevelForNamespace(
+            namespace,
+            stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
+        )
+
+        # Also set the log levels for the stdlib logger namespaces, to prevent
+        # them getting to PythonStdlibToTwistedLogger and having to be formatted
+        if "level" in namespace_config:
+            logging.getLogger(namespace).setLevel(namespace_config.get("level"))
+
+    f = FilteringLogObserver(publisher, [log_filter])
+    lco = LogContextObserver(f)
+
+    if redirect_stdlib_logging:
+        stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
+        stdliblogger = logging.getLogger()
+        stdliblogger.addHandler(stuff_into_twisted)
+
+    # Always redirect standard I/O, otherwise other logging outputs might miss
+    # it.
+    logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
+
+    return publisher
+
+
+def reload_structured_logging(*args, log_config=None) -> None:
+    warnings.warn(
+        "Currently the structured logging system can not be reloaded, doing nothing"
+    )
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
new file mode 100644
index 0000000000..7f1e8f23fe
--- /dev/null
+++ b/synapse/logging/_terse_json.py
@@ -0,0 +1,278 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Log formatters that output terse JSON.
+"""
+
+import sys
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing.io import TextIO
+
+import attr
+from simplejson import dumps
+
+from twisted.application.internet import ClientService
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    TCP4ClientEndpoint,
+    TCP6ClientEndpoint,
+)
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import FileLogObserver, Logger
+from twisted.python.failure import Failure
+
+
+def flatten_event(event: dict, metadata: dict, include_time: bool = False):
+    """
+    Flatten a Twisted logging event to an dictionary capable of being sent
+    as a log event to a logging aggregation system.
+
+    The format is vastly simplified and is not designed to be a "human readable
+    string" in the sense that traditional logs are. Instead, the structure is
+    optimised for searchability and filtering, with human-understandable log
+    keys.
+
+    Args:
+        event (dict): The Twisted logging event we are flattening.
+        metadata (dict): Additional data to include with each log message. This
+            can be information like the server name. Since the target log
+            consumer does not know who we are other than by host IP, this
+            allows us to forward through static information.
+        include_time (bool): Should we include the `time` key? If False, the
+            event time is stripped from the event.
+    """
+    new_event = {}
+
+    # If it's a failure, make the new event's log_failure be the traceback text.
+    if "log_failure" in event:
+        new_event["log_failure"] = event["log_failure"].getTraceback()
+
+    # If it's a warning, copy over a string representation of the warning.
+    if "warning" in event:
+        new_event["warning"] = str(event["warning"])
+
+    # Stdlib logging events have "log_text" as their human-readable portion,
+    # Twisted ones have "log_format". For now, include the log_format, so that
+    # context only given in the log format (e.g. what is being logged) is
+    # available.
+    if "log_text" in event:
+        new_event["log"] = event["log_text"]
+    else:
+        new_event["log"] = event["log_format"]
+
+    # We want to include the timestamp when forwarding over the network, but
+    # exclude it when we are writing to stdout. This is because the log ingester
+    # (e.g. logstash, fluentd) can add its own timestamp.
+    if include_time:
+        new_event["time"] = round(event["log_time"], 2)
+
+    # Convert the log level to a textual representation.
+    new_event["level"] = event["log_level"].name.upper()
+
+    # Ignore these keys, and do not transfer them over to the new log object.
+    # They are either useless (isError), transferred manually above (log_time,
+    # log_level, etc), or contain Python objects which are not useful for output
+    # (log_logger, log_source).
+    keys_to_delete = [
+        "isError",
+        "log_failure",
+        "log_format",
+        "log_level",
+        "log_logger",
+        "log_source",
+        "log_system",
+        "log_time",
+        "log_text",
+        "observer",
+        "warning",
+    ]
+
+    # If it's from the Twisted legacy logger (twisted.python.log), it adds some
+    # more keys we want to purge.
+    if event.get("log_namespace") == "log_legacy":
+        keys_to_delete.extend(["message", "system", "time"])
+
+    # Rather than modify the dictionary in place, construct a new one with only
+    # the content we want. The original event should be considered 'frozen'.
+    for key in event.keys():
+
+        if key in keys_to_delete:
+            continue
+
+        if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
+            # If it's a plain type, include it as is.
+            new_event[key] = event[key]
+        else:
+            # If it's not one of those basic types, write out a string
+            # representation. This should probably be a warning in development,
+            # so that we are sure we are only outputting useful data.
+            new_event[key] = str(event[key])
+
+    # Add the metadata information to the event (e.g. the server_name).
+    new_event.update(metadata)
+
+    return new_event
+
+
+def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObserver:
+    """
+    A log observer that formats events to a flattened JSON representation.
+
+    Args:
+        outFile: The file object to write to.
+        metadata: Metadata to be added to each log object.
+    """
+
+    def formatEvent(_event: dict) -> str:
+        flattened = flatten_event(_event, metadata)
+        return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+
+    return FileLogObserver(outFile, formatEvent)
+
+
+@attr.s
+class TerseJSONToTCPLogObserver(object):
+    """
+    An IObserver that writes JSON logs to a TCP target.
+
+    Args:
+        hs (HomeServer): The Homeserver that is being logged for.
+        host: The host of the logging target.
+        port: The logging target's port.
+        metadata: Metadata to be added to each log entry.
+    """
+
+    hs = attr.ib()
+    host = attr.ib(type=str)
+    port = attr.ib(type=int)
+    metadata = attr.ib(type=dict)
+    maximum_buffer = attr.ib(type=int)
+    _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+    _writer = attr.ib(default=None)
+    _logger = attr.ib(default=attr.Factory(Logger))
+
+    def start(self) -> None:
+
+        # Connect without DNS lookups if it's a direct IP.
+        try:
+            ip = ip_address(self.host)
+            if isinstance(ip, IPv4Address):
+                endpoint = TCP4ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+            elif isinstance(ip, IPv6Address):
+                endpoint = TCP6ClientEndpoint(
+                    self.hs.get_reactor(), self.host, self.port
+                )
+        except ValueError:
+            endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+        factory = Factory.forProtocol(Protocol)
+        self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+        self._service.startService()
+
+    def _write_loop(self) -> None:
+        """
+        Implement the write loop.
+        """
+        if self._writer:
+            return
+
+        self._writer = self._service.whenConnected()
+
+        @self._writer.addBoth
+        def writer(r):
+            if isinstance(r, Failure):
+                r.printTraceback(file=sys.__stderr__)
+                self._writer = None
+                self.hs.get_reactor().callLater(1, self._write_loop)
+                return
+
+            try:
+                for event in self._buffer:
+                    r.transport.write(
+                        dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
+                            "utf8"
+                        )
+                    )
+                    r.transport.write(b"\n")
+                self._buffer.clear()
+            except Exception as e:
+                sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
+
+            self._writer = False
+            self.hs.get_reactor().callLater(1, self._write_loop)
+
+    def _handle_pressure(self) -> None:
+        """
+        Handle backpressure by shedding events.
+
+        The buffer will, in this order, until the buffer is below the maximum:
+            - Shed DEBUG events
+            - Shed INFO events
+            - Shed the middle 50% of the events.
+        """
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out DEBUGs
+        self._buffer = deque(
+            filter(lambda event: event["level"] != "DEBUG", self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Strip out INFOs
+        self._buffer = deque(
+            filter(lambda event: event["level"] != "INFO", self._buffer)
+        )
+
+        if len(self._buffer) <= self.maximum_buffer:
+            return
+
+        # Cut the middle entries out
+        buffer_split = floor(self.maximum_buffer / 2)
+
+        old_buffer = self._buffer
+        self._buffer = deque()
+
+        for i in range(buffer_split):
+            self._buffer.append(old_buffer.popleft())
+
+        end_buffer = []
+        for i in range(buffer_split):
+            end_buffer.append(old_buffer.pop())
+
+        self._buffer.extend(reversed(end_buffer))
+
+    def __call__(self, event: dict) -> None:
+        flattened = flatten_event(event, self.metadata, include_time=True)
+        self._buffer.append(flattened)
+
+        # Handle backpressure, if it exists.
+        try:
+            self._handle_pressure()
+        except Exception:
+            # If handling backpressure fails,clear the buffer and log the
+            # exception.
+            self._buffer.clear()
+            self._logger.failure("Failed clearing backpressure")
+
+        # Try and write immediately.
+        self._write_loop()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index b456c31f70..63379bfb93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -25,6 +25,7 @@ See doc/log_contexts.rst for details on how this works.
 import logging
 import threading
 import types
+from typing import Any, List
 
 from twisted.internet import defer, threads
 
@@ -194,7 +195,7 @@ class LoggingContext(object):
     class Sentinel(object):
         """Sentinel to represent the root context"""
 
-        __slots__ = []
+        __slots__ = []  # type: List[Any]
 
         def __str__(self):
             return "sentinel"
@@ -202,6 +203,10 @@ class LoggingContext(object):
         def copy_to(self, record):
             pass
 
+        def copy_to_twisted_log_entry(self, record):
+            record["request"] = None
+            record["scope"] = None
+
         def start(self):
             pass
 
@@ -330,6 +335,13 @@ class LoggingContext(object):
         # we also track the current scope:
         record.scope = self.scope
 
+    def copy_to_twisted_log_entry(self, record):
+        """
+        Copy logging fields from this context to a Twisted log record.
+        """
+        record["request"] = self.request
+        record["scope"] = self.scope
+
     def start(self):
         if get_thread_id() != self.main_thread:
             logger.warning("Started logcontext %s on different thread", self)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index d2c209c471..dd296027a1 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -43,6 +43,9 @@ OpenTracing to be easily disabled in Synapse and thereby have OpenTracing as
 an optional dependency. This does however limit the number of modifiable spans
 at any point in the code to one. From here out references to `opentracing`
 in the code snippets refer to the Synapses module.
+Most methods provided in the module have a direct correlation to those provided
+by opentracing. Refer to docs there for a more in-depth documentation on some of
+the args and methods.
 
 Tracing
 -------
@@ -68,52 +71,62 @@ set a tag on the current active span.
 Tracing functions
 -----------------
 
-Functions can be easily traced using decorators. There is a decorator for
-'normal' function and for functions which are actually deferreds. The name of
+Functions can be easily traced using decorators. The name of
 the function becomes the operation name for the span.
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace, trace_deferred
+   from synapse.logging.opentracing import trace
 
-   # Start a span using 'normal_function' as the operation name
+   # Start a span using 'interesting_function' as the operation name
    @trace
-   def normal_function(*args, **kwargs):
+   def interesting_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
 
-   # Start a span using 'deferred_function' as the operation name
-   @trace_deferred
-   @defer.inlineCallbacks
-   def deferred_function(*args, **kwargs):
-       # We start
-       yield we_wait
-       # we finish
-       return something_usual_and_useful
 
 Operation names can be explicitly set for functions by using
-``trace_using_operation_name`` and
-``trace_deferred_using_operation_name``
+``trace_using_operation_name``
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import (
-       trace_using_operation_name,
-       trace_deferred_using_operation_name
-   )
+   from synapse.logging.opentracing import trace_using_operation_name
 
    @trace_using_operation_name("A *much* better operation name")
-   def normal_function(*args, **kwargs):
+   def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
 
-   @trace_deferred_using_operation_name("Another exciting operation name!")
-   @defer.inlineCallbacks
-   def deferred_function(*args, **kwargs):
-       # We start
-       yield we_wait
-       # we finish
-       return something_usual_and_useful
+Setting Tags
+------------
+
+To set a tag on the active span do
+
+.. code-block:: python
+
+   from synapse.logging.opentracing import set_tag
+
+   set_tag(tag_name, tag_value)
+
+There's a convenient decorator to tag all the args of the method. It uses
+inspection in order to use the formal parameter names prefixed with 'ARG_' as
+tag names. It uses kwarg names as tag names without the prefix.
+
+.. code-block:: python
+
+   from synapse.logging.opentracing import tag_args
+
+   @tag_args
+   def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
+       pass
+
+   set_fates("the story", "the end", "the act")
+   # This will have the following tags
+   #  - ARG_clotho: "the story"
+   #  - ARG_lachesis: "the end"
+   #  - ARG_atropos: "the act"
+   #  - father: "Zues"
+   #  - mother: "Themis"
 
 Contexts and carriers
 ---------------------
@@ -136,6 +149,9 @@ unchartered waters will require the enforcement of the whitelist.
 ``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
 in a destination and compares it to the whitelist.
 
+Most injection methods take a 'destination' arg. The context will only be injected
+if the destination matches the whitelist or the destination is None.
+
 =======
 Gotchas
 =======
@@ -161,10 +177,48 @@ from twisted.internet import defer
 
 from synapse.config import ConfigError
 
+# Helper class
+
+
+class _DummyTagNames(object):
+    """wrapper of opentracings tags. We need to have them if we
+    want to reference them without opentracing around. Clearly they
+    should never actually show up in a trace. `set_tags` overwrites
+    these with the correct ones."""
+
+    INVALID_TAG = "invalid-tag"
+    COMPONENT = INVALID_TAG
+    DATABASE_INSTANCE = INVALID_TAG
+    DATABASE_STATEMENT = INVALID_TAG
+    DATABASE_TYPE = INVALID_TAG
+    DATABASE_USER = INVALID_TAG
+    ERROR = INVALID_TAG
+    HTTP_METHOD = INVALID_TAG
+    HTTP_STATUS_CODE = INVALID_TAG
+    HTTP_URL = INVALID_TAG
+    MESSAGE_BUS_DESTINATION = INVALID_TAG
+    PEER_ADDRESS = INVALID_TAG
+    PEER_HOSTNAME = INVALID_TAG
+    PEER_HOST_IPV4 = INVALID_TAG
+    PEER_HOST_IPV6 = INVALID_TAG
+    PEER_PORT = INVALID_TAG
+    PEER_SERVICE = INVALID_TAG
+    SAMPLING_PRIORITY = INVALID_TAG
+    SERVICE = INVALID_TAG
+    SPAN_KIND = INVALID_TAG
+    SPAN_KIND_CONSUMER = INVALID_TAG
+    SPAN_KIND_PRODUCER = INVALID_TAG
+    SPAN_KIND_RPC_CLIENT = INVALID_TAG
+    SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
 try:
     import opentracing
+
+    tags = opentracing.tags
 except ImportError:
     opentracing = None
+    tags = _DummyTagNames
 try:
     from jaeger_client import Config as JaegerConfig
     from synapse.logging.scopecontextmanager import LogContextScopeManager
@@ -239,10 +293,6 @@ def init_tracer(config):
         scope_manager=LogContextScopeManager(config),
     ).initialize_tracer()
 
-    # Set up tags to be opentracing's tags
-    global tags
-    tags = opentracing.tags
-
 
 # Whitelisting
 
@@ -321,8 +371,8 @@ def start_active_span_follows_from(operation_name, contexts):
         return scope
 
 
-def start_active_span_from_context(
-    headers,
+def start_active_span_from_request(
+    request,
     operation_name,
     references=None,
     tags=None,
@@ -331,9 +381,9 @@ def start_active_span_from_context(
     finish_on_close=True,
 ):
     """
-    Extracts a span context from Twisted Headers.
+    Extracts a span context from a Twisted Request.
     args:
-        headers (twisted.web.http_headers.Headers)
+        headers (twisted.web.http.Request)
 
         For the other args see opentracing.tracer
 
@@ -347,7 +397,9 @@ def start_active_span_from_context(
     if opentracing is None:
         return _noop_context_manager()
 
-    header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()}
+    header_dict = {
+        k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+    }
     context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
 
     return opentracing.tracer.start_active_span(
@@ -435,7 +487,7 @@ def set_operation_name(operation_name):
 
 
 @only_if_tracing
-def inject_active_span_twisted_headers(headers, destination):
+def inject_active_span_twisted_headers(headers, destination, check_destination=True):
     """
     Injects a span context into twisted headers in-place
 
@@ -454,7 +506,7 @@ def inject_active_span_twisted_headers(headers, destination):
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
 
-    if not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     span = opentracing.tracer.active_span
@@ -466,7 +518,7 @@ def inject_active_span_twisted_headers(headers, destination):
 
 
 @only_if_tracing
-def inject_active_span_byte_dict(headers, destination):
+def inject_active_span_byte_dict(headers, destination, check_destination=True):
     """
     Injects a span context into a dict where the headers are encoded as byte
     strings
@@ -498,7 +550,7 @@ def inject_active_span_byte_dict(headers, destination):
 
 
 @only_if_tracing
-def inject_active_span_text_map(carrier, destination=None):
+def inject_active_span_text_map(carrier, destination, check_destination=True):
     """
     Injects a span context into a dict
 
@@ -519,7 +571,7 @@ def inject_active_span_text_map(carrier, destination=None):
         https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
     """
 
-    if destination and not whitelisted_homeserver(destination):
+    if check_destination and not whitelisted_homeserver(destination):
         return
 
     opentracing.tracer.inject(
@@ -527,6 +579,29 @@ def inject_active_span_text_map(carrier, destination=None):
     )
 
 
+def get_active_span_text_map(destination=None):
+    """
+    Gets a span context as a dict. This can be used instead of manually
+    injecting a span into an empty carrier.
+
+    Args:
+        destination (str): the name of the remote server.
+
+    Returns:
+        dict: the active span's context if opentracing is enabled, otherwise empty.
+    """
+
+    if not opentracing or (destination and not whitelisted_homeserver(destination)):
+        return {}
+
+    carrier = {}
+    opentracing.tracer.inject(
+        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+    )
+
+    return carrier
+
+
 def active_span_context_as_string():
     """
     Returns:
@@ -676,65 +751,43 @@ def tag_args(func):
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, func):
+def trace_servlet(servlet_name, extract_context=False):
     """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information"""
-    if not opentracing:
-        return func
+    tags such as the servlet_name and request information
 
-    @wraps(func)
-    @defer.inlineCallbacks
-    def _trace_servlet_inner(request, *args, **kwargs):
-        with start_active_span(
-            "incoming-client-request",
-            tags={
+    Args:
+        servlet_name (str): The name to be used for the span's operation_name
+        extract_context (bool): Whether to attempt to extract the opentracing
+            context from the request the servlet is handling.
+
+    """
+
+    def _trace_servlet_inner_1(func):
+        if not opentracing:
+            return func
+
+        @wraps(func)
+        @defer.inlineCallbacks
+        def _trace_servlet_inner(request, *args, **kwargs):
+            request_tags = {
                 "request_id": request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                 tags.HTTP_METHOD: request.get_method(),
                 tags.HTTP_URL: request.get_redacted_uri(),
                 tags.PEER_HOST_IPV6: request.getClientIP(),
-                "servlet_name": servlet_name,
-            },
-        ):
-            result = yield defer.maybeDeferred(func, request, *args, **kwargs)
-            return result
+            }
 
-    return _trace_servlet_inner
-
-
-# Helper class
-
-
-class _DummyTagNames(object):
-    """wrapper of opentracings tags. We need to have them if we
-    want to reference them without opentracing around. Clearly they
-    should never actually show up in a trace. `set_tags` overwrites
-    these with the correct ones."""
+            if extract_context:
+                scope = start_active_span_from_request(
+                    request, servlet_name, tags=request_tags
+                )
+            else:
+                scope = start_active_span(servlet_name, tags=request_tags)
 
-    INVALID_TAG = "invalid-tag"
-    COMPONENT = INVALID_TAG
-    DATABASE_INSTANCE = INVALID_TAG
-    DATABASE_STATEMENT = INVALID_TAG
-    DATABASE_TYPE = INVALID_TAG
-    DATABASE_USER = INVALID_TAG
-    ERROR = INVALID_TAG
-    HTTP_METHOD = INVALID_TAG
-    HTTP_STATUS_CODE = INVALID_TAG
-    HTTP_URL = INVALID_TAG
-    MESSAGE_BUS_DESTINATION = INVALID_TAG
-    PEER_ADDRESS = INVALID_TAG
-    PEER_HOSTNAME = INVALID_TAG
-    PEER_HOST_IPV4 = INVALID_TAG
-    PEER_HOST_IPV6 = INVALID_TAG
-    PEER_PORT = INVALID_TAG
-    PEER_SERVICE = INVALID_TAG
-    SAMPLING_PRIORITY = INVALID_TAG
-    SERVICE = INVALID_TAG
-    SPAN_KIND = INVALID_TAG
-    SPAN_KIND_CONSUMER = INVALID_TAG
-    SPAN_KIND_PRODUCER = INVALID_TAG
-    SPAN_KIND_RPC_CLIENT = INVALID_TAG
-    SPAN_KIND_RPC_SERVER = INVALID_TAG
+            with scope:
+                result = yield defer.maybeDeferred(func, request, *args, **kwargs)
+                return result
 
+        return _trace_servlet_inner
 
-tags = _DummyTagNames
+    return _trace_servlet_inner_1