summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.rst5
-rw-r--r--docs/metrics-howto.rst61
-rw-r--r--synapse/config/tls.py2
-rw-r--r--synapse/federation/transaction_queue.py4
-rw-r--r--synapse/handlers/appservice.py7
-rw-r--r--synapse/handlers/devicemessage.py14
-rw-r--r--synapse/handlers/e2e_keys.py13
-rw-r--r--synapse/http/server.py85
-rw-r--r--synapse/http/site.py10
-rw-r--r--synapse/metrics/metric.py92
-rw-r--r--synapse/replication/tcp/protocol.py19
-rw-r--r--synapse/rest/media/v1/_base.py144
-rw-r--r--synapse/rest/media/v1/download_resource.py75
-rw-r--r--synapse/rest/media/v1/media_repository.py559
-rw-r--r--synapse/rest/media/v1/media_storage.py228
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py26
-rw-r--r--synapse/rest/media/v1/storage_provider.py127
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py140
-rw-r--r--synapse/state.py45
-rw-r--r--synapse/storage/_base.py55
-rw-r--r--synapse/storage/events.py13
-rw-r--r--synapse/storage/media_repository.py23
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/delta/47/last_access_media.sql16
-rw-r--r--synapse/storage/user_directory.py9
-rw-r--r--synapse/util/logcontext.py48
-rw-r--r--synapse/util/metrics.py73
-rw-r--r--tests/crypto/test_keyring.py14
-rw-r--r--tests/util/test_logcontext.py16
29 files changed, 1404 insertions, 521 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index 24e4e7a384..a7ed49e105 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -2,7 +2,10 @@ Unreleased
 ==========
 
 synctl no longer starts the main synapse when using ``-a`` option with workers.
-A new worker file should be added with ``worker_app: synapse.app.homeserver``
+A new worker file should be added with ``worker_app: synapse.app.homeserver``.
+
+This release also begins the process of renaming a number of the metrics
+reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
 
 
 Changes in synapse v0.26.0 (2018-01-05)
diff --git a/docs/metrics-howto.rst b/docs/metrics-howto.rst
index 143cd0f42f..8acc479bc3 100644
--- a/docs/metrics-howto.rst
+++ b/docs/metrics-howto.rst
@@ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus
      metrics_port: 9092
 
    Also ensure that ``enable_metrics`` is set to ``True``.
-  
+
    Restart synapse.
 
 3. Add a prometheus target for synapse.
@@ -28,11 +28,58 @@ How to monitor Synapse metrics using Prometheus
       static_configs:
         - targets: ["my.server.here:9092"]
 
-   If your prometheus is older than 1.5.2, you will need to replace 
+   If your prometheus is older than 1.5.2, you will need to replace
    ``static_configs`` in the above with ``target_groups``.
-   
+
    Restart prometheus.
 
+
+Block and response metrics renamed for 0.27.0
+---------------------------------------------
+
+Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
+metrics reported for the resource tracking for code blocks and HTTP requests.
+
+At the same time, the corresponding ``*:total`` metrics are being renamed, as
+the ``:total`` suffix no longer makes sense in the absence of a corresponding
+``:count`` metric.
+
+To enable a graceful migration path, this release just adds new names for the
+metrics being renamed. A future release will remove the old ones.
+
+The following table shows the new metrics, and the old metrics which they are
+replacing.
+
+==================================================== ===================================================
+New name                                             Old name
+==================================================== ===================================================
+synapse_util_metrics_block_count                     synapse_util_metrics_block_timer:count
+synapse_util_metrics_block_count                     synapse_util_metrics_block_ru_utime:count
+synapse_util_metrics_block_count                     synapse_util_metrics_block_ru_stime:count
+synapse_util_metrics_block_count                     synapse_util_metrics_block_db_txn_count:count
+synapse_util_metrics_block_count                     synapse_util_metrics_block_db_txn_duration:count
+
+synapse_util_metrics_block_time_seconds              synapse_util_metrics_block_timer:total
+synapse_util_metrics_block_ru_utime_seconds          synapse_util_metrics_block_ru_utime:total
+synapse_util_metrics_block_ru_stime_seconds          synapse_util_metrics_block_ru_stime:total
+synapse_util_metrics_block_db_txn_count              synapse_util_metrics_block_db_txn_count:total
+synapse_util_metrics_block_db_txn_duration_seconds   synapse_util_metrics_block_db_txn_duration:total
+
+synapse_http_server_response_count                   synapse_http_server_requests
+synapse_http_server_response_count                   synapse_http_server_response_time:count
+synapse_http_server_response_count                   synapse_http_server_response_ru_utime:count
+synapse_http_server_response_count                   synapse_http_server_response_ru_stime:count
+synapse_http_server_response_count                   synapse_http_server_response_db_txn_count:count
+synapse_http_server_response_count                   synapse_http_server_response_db_txn_duration:count
+
+synapse_http_server_response_time_seconds            synapse_http_server_response_time:total
+synapse_http_server_response_ru_utime_seconds        synapse_http_server_response_ru_utime:total
+synapse_http_server_response_ru_stime_seconds        synapse_http_server_response_ru_stime:total
+synapse_http_server_response_db_txn_count            synapse_http_server_response_db_txn_count:total
+synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
+==================================================== ===================================================
+
+
 Standard Metric Names
 ---------------------
 
@@ -42,7 +89,7 @@ have been changed to seconds, from miliseconds.
 
 ================================== =============================
 New name                           Old name
----------------------------------- -----------------------------
+================================== =============================
 process_cpu_user_seconds_total     process_resource_utime / 1000
 process_cpu_system_seconds_total   process_resource_stime / 1000
 process_open_fds (no 'type' label) process_fds
@@ -52,8 +99,8 @@ The python-specific counts of garbage collector performance have been renamed.
 
 =========================== ======================
 New name                    Old name
---------------------------- ----------------------
-python_gc_time              reactor_gc_time      
+=========================== ======================
+python_gc_time              reactor_gc_time
 python_gc_unreachable_total reactor_gc_unreachable
 python_gc_counts            reactor_gc_counts
 =========================== ======================
@@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
 
 ==================================== =====================
 New name                             Old name
------------------------------------- ---------------------
+==================================== =====================
 python_twisted_reactor_pending_calls reactor_pending_calls
 python_twisted_reactor_tick_time     reactor_tick_time
 ==================================== =====================
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4748f71c2f..29eb012ddb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
         # certificates returned by this server match one of the fingerprints.
         #
         # Synapse automatically adds the fingerprint of its own certificate
-        # to the list. So if federation traffic is handle directly by synapse
+        # to the list. So if federation traffic is handled directly by synapse
         # then no modification to the list is required.
         #
         # If synapse is run behind a load balancer that handles the TLS then it
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 3e7809b04f..9d39f46583 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
 
 sent_transactions_counter = client_metrics.register_counter("sent_transactions")
 
+events_processed_counter = client_metrics.register_counter("events_processed")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -205,6 +207,8 @@ class TransactionQueue(object):
 
                     self._send_pdu(event, destinations)
 
+                events_processed_counter.inc_by(len(events))
+
                 yield self.store.update_federation_out_pos(
                     "events", next_token
                 )
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index feca3e4c10..3dd3fa2a27 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,6 +15,7 @@
 
 from twisted.internet import defer
 
+import synapse
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@@ -23,6 +24,10 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
 
 def log_failure(failure):
     logger.error(
@@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
                                 service, event
                             )
 
+                    events_processed_counter.inc_by(len(events))
+
                     yield self.store.set_appservice_last_pos(upper_bound)
             finally:
                 self.is_processing = False
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..d996aa90bb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.stringutils import random_string
 
 
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
         """
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.federation = hs.get_federation_sender()
 
         hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
         message_type = content["type"]
         message_id = content["message_id"]
         for user_id, by_device in content["messages"].items():
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise SynapseError(400, "Not a user here")
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
         local_messages = {}
         remote_messages = {}
         for user_id, by_device in messages.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
                     device_id: {
                         "content": message_content,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..5af8abf66b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -32,7 +32,7 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.device_handler = hs.get_device_handler()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
         # doesn't really work as part of the generic query API, because the
@@ -70,7 +70,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_ids in device_keys_query.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 local_query[user_id] = device_ids
             else:
                 remote_queries[user_id] = device_ids
@@ -170,7 +171,8 @@ class E2eKeysHandler(object):
 
         result_dict = {}
         for user_id, device_ids in query.items():
-            if not self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
                 raise SynapseError(400, "Not a user here")
@@ -213,7 +215,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_keys in query.get("one_time_keys", {}).items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6e8f4c9c5f..165c684d0d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-incoming_requests_counter = metrics.register_counter(
-    "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+    "response_count",
     labels=["method", "servlet", "tag"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_requests",
+            "_response_time:count",
+            "_response_ru_utime:count",
+            "_response_ru_stime:count",
+            "_response_db_txn_count:count",
+            "_response_db_txn_duration:count",
+        )
+    )
 )
+
 outgoing_responses_counter = metrics.register_counter(
     "responses",
     labels=["method", "code"],
 )
 
-response_timer = metrics.register_distribution(
-    "response_time",
-    labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+    "response_time_seconds",
+    labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_time:total",
+    ),
 )
 
-response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_utime:total",
+    ),
 )
 
-response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_stime:total",
+    ),
 )
 
-response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+    "response_db_txn_count", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_count:total",
+    ),
 )
 
-response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_duration:total",
+    ),
 )
 
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
 
 _next_request_id = 0
 
@@ -288,15 +322,6 @@ class JsonResource(HttpServer, resource.Resource):
 
     def _send_response(self, request, code, response_json_object,
                        response_code_message=None):
-        # could alternatively use request.notifyFinish() and flip a flag when
-        # the Deferred fires, but since the flag is RIGHT THERE it seems like
-        # a waste.
-        if request._disconnected:
-            logger.warn(
-                "Not sending response to request %s, already disconnected.",
-                request)
-            return
-
         outgoing_responses_counter.inc(request.method, str(code))
 
         # TODO: Only enable CORS for the requests that need it.
@@ -330,7 +355,7 @@ class RequestMetrics(object):
                 )
                 return
 
-        incoming_requests_counter.inc(request.method, self.name, tag)
+        response_count.inc(request.method, self.name, tag)
 
         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
@@ -349,7 +374,10 @@ class RequestMetrics(object):
             context.db_txn_count, request.method, self.name, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration_ms / 1000., request.method, self.name, tag
+        )
+        response_db_sched_duration.inc_by(
+            context.db_sched_duration_ms / 1000., request.method, self.name, tag
         )
 
 
@@ -372,6 +400,15 @@ class RootRedirect(resource.Resource):
 def respond_with_json(request, code, json_object, send_cors=False,
                       response_code_message=None, pretty_print=False,
                       version_string="", canonical_json=True):
+    # could alternatively use request.notifyFinish() and flip a flag when
+    # the Deferred fires, but since the flag is RIGHT THERE it seems like
+    # a waste.
+    if request._disconnected:
+        logger.warn(
+            "Not sending response to request %s, already disconnected.",
+            request)
+        return
+
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cd1492b1c3..e422c8dfae 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
             context = LoggingContext.current_context()
             ru_utime, ru_stime = context.get_resource_usage()
             db_txn_count = context.db_txn_count
-            db_txn_duration = context.db_txn_duration
+            db_txn_duration_ms = context.db_txn_duration_ms
+            db_sched_duration_ms = context.db_sched_duration_ms
         except Exception:
             ru_utime, ru_stime = (0, 0)
-            db_txn_count, db_txn_duration = (0, 0)
+            db_txn_count, db_txn_duration_ms = (0, 0)
 
         self.site.access_logger.info(
             "%s - %s - {%s}"
-            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
             " %sB %s \"%s %s %s\" \"%s\"",
             self.getClientIP(),
             self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
             int(time.time() * 1000) - self.start_time,
             int(ru_utime * 1000),
             int(ru_stime * 1000),
-            int(db_txn_duration * 1000),
+            db_sched_duration_ms,
+            db_txn_duration_ms,
             int(db_txn_count),
             self.sentLength,
             self.code,
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..f480aae614 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -17,16 +17,33 @@
 from itertools import chain
 
 
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
-    # flatten a list-of-lists
-    return list(chain.from_iterable(map(func, items)))
+def flatten(items):
+    """Flatten a list of lists
+
+    Args:
+        items: iterable[iterable[X]]
+
+    Returns:
+        list[X]: flattened list
+    """
+    return list(chain.from_iterable(items))
 
 
 class BaseMetric(object):
+    """Base class for metrics which report a single value per label set
+    """
 
-    def __init__(self, name, labels=[]):
-        self.name = name
+    def __init__(self, name, labels=[], alternative_names=[]):
+        """
+        Args:
+            name (str): principal name for this metric
+            labels (list(str)): names of the labels which will be reported
+                for this metric
+            alternative_names (iterable(str)): list of alternative names for
+                 this metric. This can be useful to provide a migration path
+                when renaming metrics.
+        """
+        self._names = [name] + list(alternative_names)
         self.labels = labels  # OK not to clone as we never write it
 
     def dimension(self):
@@ -36,7 +53,7 @@ class BaseMetric(object):
         return not len(self.labels)
 
     def _render_labelvalue(self, value):
-        # TODO: some kind of value escape
+        # TODO: escape backslashes, quotes and newlines
         return '"%s"' % (value)
 
     def _render_key(self, values):
@@ -47,19 +64,60 @@ class BaseMetric(object):
                       for k, v in zip(self.labels, values)])
         )
 
+    def _render_for_labels(self, label_values, value):
+        """Render this metric for a single set of labels
+
+        Args:
+            label_values (list[str]): values for each of the labels
+            value: value of the metric at with these labels
+
+        Returns:
+            iterable[str]: rendered metric
+        """
+        rendered_labels = self._render_key(label_values)
+        return (
+            "%s%s %.12g" % (name, rendered_labels, value)
+            for name in self._names
+        )
+
+    def render(self):
+        """Render this metric
+
+        Each metric is rendered as:
+
+            name{label1="val1",label2="val2"} value
+
+        https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+        Returns:
+            iterable[str]: rendered metrics
+        """
+        raise NotImplementedError()
+
 
 class CounterMetric(BaseMetric):
     """The simplest kind of metric; one that stores a monotonically-increasing
-    integer that counts events."""
+    value that counts events or running totals.
+
+    Example use cases for Counters:
+    - Number of requests processed
+    - Number of items that were inserted into a queue
+    - Total amount of data that a system has processed
+    Counters can only go up (and be reset when the process restarts).
+    """
 
     def __init__(self, *args, **kwargs):
         super(CounterMetric, self).__init__(*args, **kwargs)
 
+        # dict[list[str]]: value for each set of label values. the keys are the
+        # label values, in the same order as the labels in self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty list).
         self.counts = {}
 
         # Scalar metrics are never empty
         if self.is_scalar():
-            self.counts[()] = 0
+            self.counts[()] = 0.
 
     def inc_by(self, incr, *values):
         if len(values) != self.dimension():
@@ -77,11 +135,11 @@ class CounterMetric(BaseMetric):
     def inc(self, *values):
         self.inc_by(1, *values)
 
-    def render_item(self, k):
-        return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
     def render(self):
-        return map_concat(self.render_item, sorted(self.counts.keys()))
+        return flatten(
+            self._render_for_labels(k, self.counts[k])
+            for k in sorted(self.counts.keys())
+        )
 
 
 class CallbackMetric(BaseMetric):
@@ -98,10 +156,12 @@ class CallbackMetric(BaseMetric):
         value = self.callback()
 
         if self.is_scalar():
-            return ["%s %.12g" % (self.name, value)]
+            return list(self._render_for_labels([], value))
 
-        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
-                for k in sorted(value.keys())]
+        return flatten(
+            self._render_for_labels(k, value[k])
+            for k in sorted(value.keys())
+        )
 
 
 class DistributionMetric(object):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d59503b905..0a9a290af4 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")
 
     def on_RDATA(self, cmd):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.inc(stream_name)
+
         try:
-            row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+            row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
         except Exception:
             logger.exception(
                 "[%s] Failed to parse RDATA: %r %r",
-                self.id(), cmd.stream_name, cmd.row
+                self.id(), stream_name, cmd.row
             )
             raise
 
         if cmd.token is None:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+            self.pending_batches.setdefault(stream_name, []).append(row)
         else:
             # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(cmd.stream_name, [])
+            rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
 
-            self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+            self.handler.on_rdata(stream_name, cmd.token, rows)
 
     def on_POSITION(self, cmd):
         self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
     },
     labels=["command", "name", "conn_id"],
 )
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+    "inbound_rdata_count",
+    labels=["stream_name"],
+)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 95fa95fce3..e7ac01da01 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
-        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-        if upload_name:
-            if is_ascii(upload_name):
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename=%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-            else:
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename*=utf-8''%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-
-        # cache for at least a day.
-        # XXX: we might want to turn this off for data we don't want to
-        # recommend caching as it's sensitive or private - or at least
-        # select private. don't bother setting Expires as all our
-        # clients are smart enough to be happy with Cache-Control
-        request.setHeader(
-            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-        )
         if file_size is None:
             stat = os.stat(file_path)
             file_size = stat.st_size
 
-        request.setHeader(
-            b"Content-Length", b"%d" % (file_size,)
-        )
+        add_file_headers(request, media_type, file_size, upload_name)
 
         with open(file_path, "rb") as f:
             yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
         finish_request(request)
     else:
         respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+    """Adds the correct response headers in preparation for responding with the
+    media.
+
+    Args:
+        request (twisted.web.http.Request)
+        media_type (str): The media/content type.
+        file_size (int): Size in bytes of the media, if known.
+        upload_name (str): The name of the requested file, if any.
+    """
+    request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+    if upload_name:
+        if is_ascii(upload_name):
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename=%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+        else:
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename*=utf-8''%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+
+    # cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select private. don't bother setting Expires as all our
+    # clients are smart enough to be happy with Cache-Control
+    request.setHeader(
+        b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+    )
+
+    request.setHeader(
+        b"Content-Length", b"%d" % (file_size,)
+    )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+    """Responds to the request with given responder. If responder is None then
+    returns 404.
+
+    Args:
+        request (twisted.web.http.Request)
+        responder (Responder|None)
+        media_type (str): The media/content type.
+        file_size (int|None): Size in bytes of the media. If not known it should be None
+        upload_name (str|None): The name of the requested file, if any.
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    add_file_headers(request, media_type, file_size, upload_name)
+    with responder:
+        yield responder.write_to_consumer(request)
+    finish_request(request)
+
+
+class Responder(object):
+    """Represents a response that can be streamed to the requester.
+
+    Responder is a context manager which *must* be used, so that any resources
+    held can be cleaned up.
+    """
+    def write_to_consumer(self, consumer):
+        """Stream response into consumer
+
+        Args:
+            consumer (IConsumer)
+
+        Returns:
+            Deferred: Resolves once the response has finished being written
+        """
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class FileInfo(object):
+    """Details about a requested/uploaded file.
+
+    Attributes:
+        server_name (str): The server name where the media originated from,
+            or None if local.
+        file_id (str): The local ID of the file. For local files this is the
+            same as the media_id
+        url_cache (bool): If the file is for the url preview cache
+        thumbnail (bool): Whether the file is a thumbnail or not.
+        thumbnail_width (int)
+        thumbnail_height (int)
+        thumbnail_method (str)
+        thumbnail_type (str): Content type of thumbnail, e.g. image/png
+    """
+    def __init__(self, server_name, file_id, url_cache=False,
+                 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+                 thumbnail_method=None, thumbnail_type=None):
+        self.server_name = server_name
+        self.file_id = file_id
+        self.url_cache = url_cache
+        self.thumbnail = thumbnail
+        self.thumbnail_width = thumbnail_width
+        self.thumbnail_height = thumbnail_height
+        self.thumbnail_method = thumbnail_method
+        self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import synapse.http.servlet
 
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
 from twisted.web.resource import Resource
 from synapse.http.server import request_handler, set_cors_headers
 
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
     def __init__(self, hs, media_repo):
         Resource.__init__(self)
 
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
-        self.version_string = hs.version_string
+
+        # Both of these are expected by @request_handler()
         self.clock = hs.get_clock()
+        self.version_string = hs.version_string
 
     def render_GET(self, request):
         self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
         )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
-            yield self._respond_local_file(request, media_id, name)
+            yield self.media_repo.get_local_media(request, media_id, name)
         else:
-            yield self._respond_remote_file(
-                request, server_name, media_id, name
-            )
-
-    @defer.inlineCallbacks
-    def _respond_local_file(self, request, media_id, name):
-        media_info = yield self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
-            respond_404(request)
-            return
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        if media_info["url_cache"]:
-            # TODO: Check the file still exists, if it doesn't we can redownload
-            # it from the url `media_info["url_cache"]`
-            file_path = self.filepaths.url_cache_filepath(media_id)
-        else:
-            file_path = self.filepaths.local_media_filepath(media_id)
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
-
-    @defer.inlineCallbacks
-    def _respond_remote_file(self, request, server_name, media_id, name):
-        # don't forward requests for remote media if allow_remote is false
-        allow_remote = synapse.http.servlet.parse_boolean(
-            request, "allow_remote", default=True)
-        if not allow_remote:
-            logger.info(
-                "Rejecting request for remote media %s/%s due to allow_remote",
-                server_name, media_id,
-            )
-            respond_404(request)
-            return
-
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        filesystem_id = media_info["filesystem_id"]
-        upload_name = name if name else media_info["upload_name"]
-
-        file_path = self.filepaths.remote_media_filepath(
-            server_name, filesystem_id
-        )
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
+            allow_remote = synapse.http.servlet.parse_boolean(
+                request, "allow_remote", default=True)
+            if not allow_remote:
+                logger.info(
+                    "Rejecting request for remote media %s/%s due to allow_remote",
+                    server_name, media_id,
+                )
+                respond_404(request)
+                return
+
+            yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index eed9056a2f..b2c76440b7 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
 import twisted.web.http
 from twisted.web.resource import Resource
 
+from ._base import respond_404, FileInfo, respond_with_responder
 from .upload_resource import UploadResource
 from .download_resource import DownloadResource
 from .thumbnail_resource import ThumbnailResource
@@ -25,6 +27,10 @@ from .identicon_resource import IdenticonResource
 from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths
 from .thumbnailer import Thumbnailer
+from .storage_provider import (
+    StorageProviderWrapper, FileStorageProviderBackend,
+)
+from .media_storage import MediaStorage
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.util.stringutils import random_string
@@ -33,7 +39,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
 
 from synapse.util.async import Linearizer
 from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
 
 import os
@@ -47,7 +53,7 @@ import urlparse
 logger = logging.getLogger(__name__)
 
 
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository(object):
@@ -63,96 +69,64 @@ class MediaRepository(object):
         self.primary_base_path = hs.config.media_store_path
         self.filepaths = MediaFilePaths(self.primary_base_path)
 
-        self.backup_base_path = hs.config.backup_media_store_path
-
-        self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
-
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
         self.recently_accessed_remotes = set()
+        self.recently_accessed_locals = set()
 
-        self.clock.looping_call(
-            self._update_recently_accessed_remotes,
-            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
-        )
+        # List of StorageProviders where we should search for media and
+        # potentially upload to.
+        storage_providers = []
 
-    @defer.inlineCallbacks
-    def _update_recently_accessed_remotes(self):
-        media = self.recently_accessed_remotes
-        self.recently_accessed_remotes = set()
+        # TODO: Move this into config and allow other storage providers to be
+        # defined.
+        if hs.config.backup_media_store_path:
+            backend = FileStorageProviderBackend(
+                self.primary_base_path, hs.config.backup_media_store_path,
+            )
+            provider = StorageProviderWrapper(
+                backend,
+                store=True,
+                store_synchronous=hs.config.synchronous_backup_media_store,
+                store_remote=True,
+            )
+            storage_providers.append(provider)
 
-        yield self.store.update_cached_last_access_time(
-            media, self.clock.time_msec()
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
         )
 
-    @staticmethod
-    def _makedirs(filepath):
-        dirname = os.path.dirname(filepath)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-
-    @staticmethod
-    def _write_file_synchronously(source, fname):
-        """Write `source` to the path `fname` synchronously. Should be called
-        from a thread.
-
-        Args:
-            source: A file like object to be written
-            fname (str): Path to write to
-        """
-        MediaRepository._makedirs(fname)
-        source.seek(0)  # Ensure we read from the start of the file
-        with open(fname, "wb") as f:
-            shutil.copyfileobj(source, f)
+        self.clock.looping_call(
+            self._update_recently_accessed,
+            UPDATE_RECENTLY_ACCESSED_TS,
+        )
 
     @defer.inlineCallbacks
-    def write_to_file_and_backup(self, source, path):
-        """Write `source` to the on disk media store, and also the backup store
-        if configured.
-
-        Args:
-            source: A file like object that should be written
-            path (str): Relative path to write file to
-
-        Returns:
-            Deferred[str]: the file path written to in the primary media store
-        """
-        fname = os.path.join(self.primary_base_path, path)
-
-        # Write to the main repository
-        yield make_deferred_yieldable(threads.deferToThread(
-            self._write_file_synchronously, source, fname,
-        ))
+    def _update_recently_accessed(self):
+        remote_media = self.recently_accessed_remotes
+        self.recently_accessed_remotes = set()
 
-        # Write to backup repository
-        yield self.copy_to_backup(path)
+        local_media = self.recently_accessed_locals
+        self.recently_accessed_locals = set()
 
-        defer.returnValue(fname)
+        yield self.store.update_cached_last_access_time(
+            local_media, remote_media, self.clock.time_msec()
+        )
 
-    @defer.inlineCallbacks
-    def copy_to_backup(self, path):
-        """Copy a file from the primary to backup media store, if configured.
+    def mark_recently_accessed(self, server_name, media_id):
+        """Mark the given media as recently accessed.
 
         Args:
-            path(str): Relative path to write file to
+            server_name (str|None): Origin server of media, or None if local
+            media_id (str): The media ID of the content
         """
-        if self.backup_base_path:
-            primary_fname = os.path.join(self.primary_base_path, path)
-            backup_fname = os.path.join(self.backup_base_path, path)
-
-            # We can either wait for successful writing to the backup repository
-            # or write in the background and immediately return
-            if self.synchronous_backup_media_store:
-                yield make_deferred_yieldable(threads.deferToThread(
-                    shutil.copyfile, primary_fname, backup_fname,
-                ))
-            else:
-                preserve_fn(threads.deferToThread)(
-                    shutil.copyfile, primary_fname, backup_fname,
-                )
+        if server_name:
+            self.recently_accessed_remotes.add((server_name, media_id))
+        else:
+            self.recently_accessed_locals.add(media_id)
 
     @defer.inlineCallbacks
     def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +145,13 @@ class MediaRepository(object):
         """
         media_id = random_string(24)
 
-        fname = yield self.write_to_file_and_backup(
-            content, self.filepaths.local_media_filepath_rel(media_id)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=media_id,
         )
 
+        fname = yield self.media_storage.store_file(content, file_info)
+
         logger.info("Stored local media in file %r", fname)
 
         yield self.store.store_local_media(
@@ -185,134 +162,263 @@ class MediaRepository(object):
             media_length=content_length,
             user_id=auth_user,
         )
-        media_info = {
-            "media_type": media_type,
-            "media_length": content_length,
-        }
 
-        yield self._generate_thumbnails(None, media_id, media_info)
+        yield self._generate_thumbnails(
+            None, media_id, media_id, media_type,
+        )
 
         defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
 
     @defer.inlineCallbacks
-    def get_remote_media(self, server_name, media_id):
+    def get_local_media(self, request, media_id, name):
+        """Responds to reqests for local media, if exists, or returns 404.
+
+        Args:
+            request(twisted.web.http.Request)
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content.)
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        media_info = yield self.store.get_local_media(media_id)
+        if not media_info or media_info["quarantined_by"]:
+            respond_404(request)
+            return
+
+        self.mark_recently_accessed(None, media_id)
+
+        media_type = media_info["media_type"]
+        media_length = media_info["media_length"]
+        upload_name = name if name else media_info["upload_name"]
+        url_cache = media_info["url_cache"]
+
+        file_info = FileInfo(
+            None, media_id,
+            url_cache=url_cache,
+        )
+
+        responder = yield self.media_storage.fetch_media(file_info)
+        yield respond_with_responder(
+            request, responder, media_type, media_length, upload_name,
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_media(self, request, server_name, media_id, name):
+        """Respond to requests for remote media.
+
+        Args:
+            request(twisted.web.http.Request)
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        self.mark_recently_accessed(server_name, media_id)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
+        key = (server_name, media_id)
+        with (yield self.remote_media_linearizer.queue(key)):
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # We deliberately stream the file outside the lock
+        if responder:
+            media_type = media_info["media_type"]
+            media_length = media_info["media_length"]
+            upload_name = name if name else media_info["upload_name"]
+            yield respond_with_responder(
+                request, responder, media_type, media_length, upload_name,
+            )
+        else:
+            respond_404(request)
+
+    @defer.inlineCallbacks
+    def get_remote_media_info(self, server_name, media_id):
+        """Gets the media info associated with the remote file, downloading
+        if necessary.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[dict]: The media_info of the file
+        """
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
         key = (server_name, media_id)
         with (yield self.remote_media_linearizer.queue(key)):
-            media_info = yield self._get_remote_media_impl(server_name, media_id)
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # Ensure we actually use the responder so that it releases resources
+        if responder:
+            with responder:
+                pass
+
         defer.returnValue(media_info)
 
     @defer.inlineCallbacks
     def _get_remote_media_impl(self, server_name, media_id):
+        """Looks for media in local cache, if not there then attempt to
+        download from remote server.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[(Responder, media_info)]
+        """
         media_info = yield self.store.get_cached_remote_media(
             server_name, media_id
         )
-        if not media_info:
-            media_info = yield self._download_remote_file(
-                server_name, media_id
-            )
-        elif media_info["quarantined_by"]:
-            raise NotFoundError()
+
+        # file_id is the ID we use to track the file locally. If we've already
+        # seen the file then reuse the existing ID, otherwise genereate a new
+        # one.
+        if media_info:
+            file_id = media_info["filesystem_id"]
         else:
-            self.recently_accessed_remotes.add((server_name, media_id))
-            yield self.store.update_cached_last_access_time(
-                [(server_name, media_id)], self.clock.time_msec()
-            )
-        defer.returnValue(media_info)
+            file_id = random_string(24)
+
+        file_info = FileInfo(server_name, file_id)
+
+        # If we have an entry in the DB, try and look for it
+        if media_info:
+            if media_info["quarantined_by"]:
+                logger.info("Media is quarantined")
+                raise NotFoundError()
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            if responder:
+                defer.returnValue((responder, media_info))
+
+        # Failed to find the file anywhere, lets download it.
+
+        media_info = yield self._download_remote_file(
+            server_name, media_id, file_id
+        )
+
+        responder = yield self.media_storage.fetch_media(file_info)
+        defer.returnValue((responder, media_info))
 
     @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id):
-        file_id = random_string(24)
+    def _download_remote_file(self, server_name, media_id, file_id):
+        """Attempt to download the remote file from the given server name,
+        using the given file_id as the local id.
+
+        Args:
+            server_name (str): Originating server
+            media_id (str): The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            file_id (str): Local file ID
+
+        Returns:
+            Deferred[MediaInfo]
+        """
 
-        fpath = self.filepaths.remote_media_filepath_rel(
-            server_name, file_id
+        file_info = FileInfo(
+            server_name=server_name,
+            file_id=file_id,
         )
-        fname = os.path.join(self.primary_base_path, fpath)
-        self._makedirs(fname)
 
-        try:
-            with open(fname, "wb") as f:
-                request_path = "/".join((
-                    "/_matrix/media/v1/download", server_name, media_id,
-                ))
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            request_path = "/".join((
+                "/_matrix/media/v1/download", server_name, media_id,
+            ))
+            try:
+                length, headers = yield self.client.get_file(
+                    server_name, request_path, output_stream=f,
+                    max_size=self.max_upload_size, args={
+                        # tell the remote server to 404 if it doesn't
+                        # recognise the server_name, to make sure we don't
+                        # end up with a routing loop.
+                        "allow_remote": "false",
+                    }
+                )
+            except twisted.internet.error.DNSLookupError as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %r",
+                            server_name, media_id, e)
+                raise NotFoundError()
+
+            except HttpResponseException as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %s",
+                            server_name, media_id, e.response)
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise SynapseError.from_http_response_exception(e)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise
+            except NotRetryingDestination:
+                logger.warn("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            yield finish()
+
+        media_type = headers["Content-Type"][0]
+
+        time_now_ms = self.clock.time_msec()
+
+        content_disposition = headers.get("Content-Disposition", None)
+        if content_disposition:
+            _, params = cgi.parse_header(content_disposition[0],)
+            upload_name = None
+
+            # First check if there is a valid UTF-8 filename
+            upload_name_utf8 = params.get("filename*", None)
+            if upload_name_utf8:
+                if upload_name_utf8.lower().startswith("utf-8''"):
+                    upload_name = upload_name_utf8[7:]
+
+            # If there isn't check for an ascii name.
+            if not upload_name:
+                upload_name_ascii = params.get("filename", None)
+                if upload_name_ascii and is_ascii(upload_name_ascii):
+                    upload_name = upload_name_ascii
+
+            if upload_name:
+                upload_name = urlparse.unquote(upload_name)
                 try:
-                    length, headers = yield self.client.get_file(
-                        server_name, request_path, output_stream=f,
-                        max_size=self.max_upload_size, args={
-                            # tell the remote server to 404 if it doesn't
-                            # recognise the server_name, to make sure we don't
-                            # end up with a routing loop.
-                            "allow_remote": "false",
-                        }
-                    )
-                except twisted.internet.error.DNSLookupError as e:
-                    logger.warn("HTTP error fetching remote media %s/%s: %r",
-                                server_name, media_id, e)
-                    raise NotFoundError()
-
-                except HttpResponseException as e:
-                    logger.warn("HTTP error fetching remote media %s/%s: %s",
-                                server_name, media_id, e.response)
-                    if e.code == twisted.web.http.NOT_FOUND:
-                        raise SynapseError.from_http_response_exception(e)
-                    raise SynapseError(502, "Failed to fetch remote media")
-
-                except SynapseError:
-                    logger.exception("Failed to fetch remote media %s/%s",
-                                     server_name, media_id)
-                    raise
-                except NotRetryingDestination:
-                    logger.warn("Not retrying destination %r", server_name)
-                    raise SynapseError(502, "Failed to fetch remote media")
-                except Exception:
-                    logger.exception("Failed to fetch remote media %s/%s",
-                                     server_name, media_id)
-                    raise SynapseError(502, "Failed to fetch remote media")
-
-            yield self.copy_to_backup(fpath)
-
-            media_type = headers["Content-Type"][0]
-            time_now_ms = self.clock.time_msec()
-
-            content_disposition = headers.get("Content-Disposition", None)
-            if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
-                upload_name = None
-
-                # First check if there is a valid UTF-8 filename
-                upload_name_utf8 = params.get("filename*", None)
-                if upload_name_utf8:
-                    if upload_name_utf8.lower().startswith("utf-8''"):
-                        upload_name = upload_name_utf8[7:]
-
-                # If there isn't check for an ascii name.
-                if not upload_name:
-                    upload_name_ascii = params.get("filename", None)
-                    if upload_name_ascii and is_ascii(upload_name_ascii):
-                        upload_name = upload_name_ascii
-
-                if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
-                    try:
-                        upload_name = upload_name.decode("utf-8")
-                    except UnicodeDecodeError:
-                        upload_name = None
-            else:
-                upload_name = None
-
-            logger.info("Stored remote media in file %r", fname)
-
-            yield self.store.store_cached_remote_media(
-                origin=server_name,
-                media_id=media_id,
-                media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
-                upload_name=upload_name,
-                media_length=length,
-                filesystem_id=file_id,
-            )
-        except Exception:
-            os.remove(fname)
-            raise
+                    upload_name = upload_name.decode("utf-8")
+                except UnicodeDecodeError:
+                    upload_name = None
+        else:
+            upload_name = None
+
+        logger.info("Stored remote media in file %r", fname)
+
+        yield self.store.store_cached_remote_media(
+            origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=length,
+            filesystem_id=file_id,
+        )
 
         media_info = {
             "media_type": media_type,
@@ -323,7 +429,7 @@ class MediaRepository(object):
         }
 
         yield self._generate_thumbnails(
-            server_name, media_id, media_info
+            server_name, media_id, file_id, media_type,
         )
 
         defer.returnValue(media_info)
@@ -368,11 +474,18 @@ class MediaRepository(object):
 
         if t_byte_source:
             try:
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source,
-                    self.filepaths.local_media_thumbnail_rel(
-                        media_id, t_width, t_height, t_type, t_method
-                    )
+                file_info = FileInfo(
+                    server_name=None,
+                    file_id=media_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -400,11 +513,18 @@ class MediaRepository(object):
 
         if t_byte_source:
             try:
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source,
-                    self.filepaths.remote_media_thumbnail_rel(
-                        server_name, file_id, t_width, t_height, t_type, t_method
-                    )
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=media_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -421,21 +541,22 @@ class MediaRepository(object):
             defer.returnValue(output_path)
 
     @defer.inlineCallbacks
-    def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+                             url_cache=False):
         """Generate and store thumbnails for an image.
 
         Args:
-            server_name(str|None): The server name if remote media, else None if local
-            media_id(str)
-            media_info(dict)
-            url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+            server_name (str|None): The server name if remote media, else None if local
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content)
+            file_id (str): Local file ID
+            media_type (str): The content type of the file
+            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
                 used exclusively by the url previewer
 
         Returns:
             Deferred[dict]: Dict with "width" and "height" keys of original image
         """
-        media_type = media_info["media_type"]
-        file_id = media_info.get("filesystem_id")
         requirements = self._get_thumbnail_requirements(media_type)
         if not requirements:
             return
@@ -472,20 +593,6 @@ class MediaRepository(object):
 
         # Now we generate the thumbnails for each dimension, store it
         for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
-            # Work out the correct file name for thumbnail
-            if server_name:
-                file_path = self.filepaths.remote_media_thumbnail_rel(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-            elif url_cache:
-                file_path = self.filepaths.url_cache_thumbnail_rel(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-            else:
-                file_path = self.filepaths.local_media_thumbnail_rel(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-
             # Generate the thumbnail
             if t_method == "crop":
                 t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +612,19 @@ class MediaRepository(object):
                 continue
 
             try:
-                # Write to disk
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source, file_path,
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                    url_cache=url_cache,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -620,7 +737,11 @@ class MediaRepositoryResource(Resource):
 
         self.putChild("upload", UploadResource(hs, media_repo))
         self.putChild("download", DownloadResource(hs, media_repo))
-        self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+        self.putChild("thumbnail", ThumbnailResource(
+            hs, media_repo, media_repo.media_storage,
+        ))
         self.putChild("identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+            self.putChild("preview_url", PreviewUrlResource(
+                hs, media_repo, media_repo.media_storage,
+            ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..001e84578e
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,228 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vecotr Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from ._base import Responder
+
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+    """Responsible for storing/fetching files from local sources.
+
+    Args:
+        local_media_directory (str): Base path where we store media on disk
+        filepaths (MediaFilePaths)
+        storage_providers ([StorageProvider]): List of StorageProvider that are
+            used to fetch and store files.
+    """
+
+    def __init__(self, local_media_directory, filepaths, storage_providers):
+        self.local_media_directory = local_media_directory
+        self.filepaths = filepaths
+        self.storage_providers = storage_providers
+
+    @defer.inlineCallbacks
+    def store_file(self, source, file_info):
+        """Write `source` to the on disk media store, and also any other
+        configured storage providers
+
+        Args:
+            source: A file like object that should be written
+            file_info (FileInfo): Info about the file to store
+
+        Returns:
+            Deferred[str]: the file path written to in the primary media store
+        """
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        # Write to the main repository
+        yield make_deferred_yieldable(threads.deferToThread(
+            _write_file_synchronously, source, fname,
+        ))
+
+        defer.returnValue(fname)
+
+    @contextlib.contextmanager
+    def store_into_file(self, file_info):
+        """Context manager used to get a file like object to write into, as
+        described by file_info.
+
+        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+        like object that can be written to, fname is the absolute path of file
+        on disk, and finish_cb is a function that returns a Deferred.
+
+        fname can be used to read the contents from after upload, e.g. to
+        generate thumbnails.
+
+        finish_cb must be called and waited on after the file has been
+        successfully been written to. Should not be called if there was an
+        error.
+
+        Args:
+            file_info (FileInfo): Info about the file to store
+
+        Example:
+
+            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+                # .. write into f ...
+                yield finish_cb()
+        """
+
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        finished_called = [False]
+
+        @defer.inlineCallbacks
+        def finish():
+            for provider in self.storage_providers:
+                yield provider.store_file(path, file_info)
+
+            finished_called[0] = True
+
+        try:
+            with open(fname, "wb") as f:
+                yield f, fname, finish
+        except Exception:
+            t, v, tb = sys.exc_info()
+            try:
+                os.remove(fname)
+            except Exception:
+                pass
+            raise t, v, tb
+
+        if not finished_called:
+            raise Exception("Finished callback not called")
+
+    @defer.inlineCallbacks
+    def fetch_media(self, file_info):
+        """Attempts to fetch media described by file_info from the local cache
+        and configured storage providers.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[Responder|None]: Returns a Responder if the file was found,
+                otherwise None.
+        """
+
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(FileResponder(open(local_path, "rb")))
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                defer.returnValue(res)
+
+        defer.returnValue(None)
+
+    def _file_info_to_path(self, file_info):
+        """Converts file_info into a relative path.
+
+        The path is suitable for storing files under a directory, e.g. used to
+        store files on local FS under the base media repository directory.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            str
+        """
+        if file_info.url_cache:
+            return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+        if file_info.server_name:
+            if file_info.thumbnail:
+                return self.filepaths.remote_media_thumbnail_rel(
+                    server_name=file_info.server_name,
+                    file_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method
+                )
+            return self.filepaths.remote_media_filepath_rel(
+                file_info.server_name, file_info.file_id,
+            )
+
+        if file_info.thumbnail:
+            return self.filepaths.local_media_thumbnail_rel(
+                media_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+                method=file_info.thumbnail_method
+            )
+        return self.filepaths.local_media_filepath_rel(
+            file_info.file_id,
+        )
+
+
+def _write_file_synchronously(source, fname):
+    """Write `source` to the path `fname` synchronously. Should be called
+    from a thread.
+
+    Args:
+        source: A file like object to be written
+        fname (str): Path to write to
+    """
+    dirname = os.path.dirname(fname)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    source.seek(0)  # Ensure we read from the start of the file
+    with open(fname, "wb") as f:
+        shutil.copyfileobj(source, f)
+
+
+class FileResponder(Responder):
+    """Wraps an open file that can be sent to a request.
+
+    Args:
+        open_file (file): A file like object to be streamed ot the client,
+            is closed when finished streaming.
+    """
+    def __init__(self, open_file):
+        self.open_file = open_file
+
+    @defer.inlineCallbacks
+    def write_to_consumer(self, consumer):
+        yield FileSender().beginFileTransfer(self.open_file, consumer)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 40d2e664eb..981f01e417 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
 from twisted.web.resource import Resource
 
+from ._base import FileInfo
+
 from synapse.api.errors import (
     SynapseError, Codes,
 )
@@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
 class PreviewUrlResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.auth = hs.get_auth()
@@ -62,6 +64,7 @@ class PreviewUrlResource(Resource):
         self.client = SpiderHttpClient(hs)
         self.media_repo = media_repo
         self.primary_base_path = media_repo.primary_base_path
+        self.media_storage = media_storage
 
         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
@@ -182,8 +185,10 @@ class PreviewUrlResource(Resource):
         logger.debug("got media_info of '%s'" % media_info)
 
         if _is_media(media_info['media_type']):
+            file_id = media_info['filesystem_id']
             dims = yield self.media_repo._generate_thumbnails(
-                None, media_info['filesystem_id'], media_info, url_cache=True,
+                None, file_id, file_id, media_info["media_type"],
+                url_cache=True,
             )
 
             og = {
@@ -228,8 +233,10 @@ class PreviewUrlResource(Resource):
 
                 if _is_media(image_info['media_type']):
                     # TODO: make sure we don't choke on white-on-transparent images
+                    file_id = image_info['filesystem_id']
                     dims = yield self.media_repo._generate_thumbnails(
-                        None, image_info['filesystem_id'], image_info, url_cache=True,
+                        None, file_id, file_id, image_info["media_type"],
+                        url_cache=True,
                     )
                     if dims:
                         og["og:image:width"] = dims['width']
@@ -273,19 +280,21 @@ class PreviewUrlResource(Resource):
 
         file_id = datetime.date.today().isoformat() + '_' + random_string(16)
 
-        fpath = self.filepaths.url_cache_filepath_rel(file_id)
-        fname = os.path.join(self.primary_base_path, fpath)
-        self.media_repo._makedirs(fname)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=file_id,
+            url_cache=True,
+        )
 
         try:
-            with open(fname, "wb") as f:
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
                 # FIXME: pass through 404s and other error messages nicely
 
-            yield self.media_repo.copy_to_backup(fpath)
+                yield finish()
 
             media_type = headers["Content-Type"][0]
             time_now_ms = self.clock.time_msec()
@@ -327,7 +336,6 @@ class PreviewUrlResource(Resource):
             )
 
         except Exception as e:
-            os.remove(fname)
             raise SynapseError(
                 500, ("Failed to download content: %s" % e),
                 Codes.UNKNOWN
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..2ad602e101
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.util.logcontext import preserve_fn
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+    """A storage provider is a service that can store uploaded media and
+    retrieve them.
+    """
+    def store_file(self, path, file_info):
+        """Store the file described by file_info. The actual contents can be
+        retrieved by reading the file in file_info.upload_path.
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred
+        """
+        pass
+
+    def fetch(self, path, file_info):
+        """Attempt to fetch the file described by file_info and stream it
+        into writer.
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred(Responder): Returns a Responder if the provider has the file,
+                otherwise returns None.
+        """
+        pass
+
+
+class StorageProviderWrapper(StorageProvider):
+    """Wraps a storage provider and provides various config options
+
+    Args:
+        backend (StorageProvider)
+        store (bool): Whether to store new files or not.
+        store_synchronous (bool): Whether to wait for file to be successfully
+            uploaded, or todo the upload in the backgroud.
+        store_remote (bool): Whether remote media should be uploaded
+    """
+    def __init__(self, backend, store, store_synchronous, store_remote):
+        self.backend = backend
+        self.store = store
+        self.store_synchronous = store_synchronous
+        self.store_remote = store_remote
+
+    def store_file(self, path, file_info):
+        if not self.store:
+            return defer.succeed(None)
+
+        if file_info.server_name and not self.store_remote:
+            return defer.succeed(None)
+
+        if self.store_synchronous:
+            return self.backend.store_file(path, file_info)
+        else:
+            # TODO: Handle errors.
+            preserve_fn(self.backend.store_file)(path, file_info)
+            return defer.succeed(None)
+
+    def fetch(self, path, file_info):
+        return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+    """A storage provider that stores files in a directory on a filesystem.
+
+    Args:
+        cache_directory (str): Base path of the local media repository
+        base_directory (str): Base path to store new files
+    """
+
+    def __init__(self, cache_directory, base_directory):
+        self.cache_directory = cache_directory
+        self.base_directory = base_directory
+
+    def store_file(self, path, file_info):
+        """See StorageProvider.store_file"""
+
+        primary_fname = os.path.join(self.cache_directory, path)
+        backup_fname = os.path.join(self.base_directory, path)
+
+        dirname = os.path.dirname(backup_fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        return threads.deferToThread(
+            shutil.copyfile, primary_fname, backup_fname,
+        )
+
+    def fetch(self, path, file_info):
+        """See StorageProvider.fetch"""
+
+        backup_fname = os.path.join(self.base_directory, path)
+        if os.path.isfile(backup_fname):
+            return FileResponder(open(backup_fname, "rb"))
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 53e48aba22..c09f2dec41 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 
 
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+    parse_media_id, respond_404, respond_with_file, FileInfo,
+    respond_with_responder,
+)
 from twisted.web.resource import Resource
 from synapse.http.servlet import parse_string, parse_integer
 from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
 class ThumbnailResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.store = hs.get_datastore()
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
+        self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
         self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
                 yield self._respond_local_thumbnail(
                     request, media_id, width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(server_name, media_id)
         else:
             if self.dynamic_thumbnails:
                 yield self._select_or_generate_remote_thumbnail(
@@ -75,13 +79,18 @@ class ThumbnailResource(Resource):
                     request, server_name, media_id,
                     width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(None, media_id)
 
     @defer.inlineCallbacks
     def _respond_local_thumbnail(self, request, media_id, width, height,
                                  method, m_type):
         media_info = yield self.store.get_local_media(media_id)
 
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info:
+            respond_404(request)
+            return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
             respond_404(request)
             return
 
@@ -91,24 +100,24 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
-
-            if media_info["url_cache"]:
-                # TODO: Check the file still exists, if it doesn't we can redownload
-                # it from the url `media_info["url_cache"]`
-                file_path = self.filepaths.url_cache_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method,
-                )
-            else:
-                file_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method,
-                )
-            yield respond_with_file(request, t_type, file_path)
 
+            file_info = FileInfo(
+                server_name=None, file_id=media_id,
+                url_cache=media_info["url_cache"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+            )
+
+            t_type = file_info.thumbnail_type
+            t_length = thumbnail_info["thumbnail_length"]
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
         else:
+            logger.info("Couldn't find any generated thumbnails")
             respond_404(request)
 
     @defer.inlineCallbacks
@@ -117,7 +126,11 @@ class ThumbnailResource(Resource):
                                             desired_type):
         media_info = yield self.store.get_local_media(media_id)
 
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info:
+            respond_404(request)
+            return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
             respond_404(request)
             return
 
@@ -129,22 +142,25 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                if media_info["url_cache"]:
-                    # TODO: Check the file still exists, if it doesn't we can redownload
-                    # it from the url `media_info["url_cache"]`
-                    file_path = self.filepaths.url_cache_thumbnail(
-                        media_id, desired_width, desired_height, desired_type,
-                        desired_method,
-                    )
-                else:
-                    file_path = self.filepaths.local_media_thumbnail(
-                        media_id, desired_width, desired_height, desired_type,
-                        desired_method,
-                    )
-                yield respond_with_file(request, desired_type, file_path)
-                return
-
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                file_info = FileInfo(
+                    server_name=None, file_id=media_id,
+                    url_cache=media_info["url_cache"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
+                )
+
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_local_exact_thumbnail(
@@ -154,13 +170,14 @@ class ThumbnailResource(Resource):
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
+            logger.warn("Failed to generate thumbnail")
             respond_404(request)
 
     @defer.inlineCallbacks
     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
                                              desired_width, desired_height,
                                              desired_method, desired_type):
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -175,14 +192,24 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                file_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, desired_width, desired_height,
-                    desired_type, desired_method,
+                file_info = FileInfo(
+                    server_name=server_name, file_id=media_info["filesystem_id"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
                 )
-                yield respond_with_file(request, desired_type, file_path)
-                return
 
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -193,6 +220,7 @@ class ThumbnailResource(Resource):
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
+            logger.warn("Failed to generate thumbnail")
             respond_404(request)
 
     @defer.inlineCallbacks
@@ -201,7 +229,7 @@ class ThumbnailResource(Resource):
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
-        yield self.media_repo.get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -211,18 +239,22 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
-            file_id = thumbnail_info["filesystem_id"]
+            file_info = FileInfo(
+                server_name=server_name, file_id=media_info["filesystem_id"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+            )
+
+            t_type = file_info.thumbnail_type
             t_length = thumbnail_info["thumbnail_length"]
 
-            file_path = self.filepaths.remote_media_thumbnail(
-                server_name, file_id, t_width, t_height, t_type, t_method,
-            )
-            yield respond_with_file(request, t_type, file_path, t_length)
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
         else:
+            logger.info("Failed to find any generated thumbnails")
             respond_404(request)
 
     def _select_thumbnail(self, desired_width, desired_height, desired_method,
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..1f9abf9d3d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -341,7 +341,7 @@ class StateHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events(
+                    new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
                         state_map_factory=lambda ev_ids: self.store.get_events(
                             ev_ids, get_prev_content=False, check_redacted=False,
@@ -404,7 +404,7 @@ class StateHandler(object):
         }
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events(state_set_ids, state_map)
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
 
         new_state = {
             key: state_map[ev_id] for key, ev_id in new_state.items()
@@ -420,19 +420,17 @@ def _ordered_events(events):
     return sorted(events, key=key_func)
 
 
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
-        state_map_factory(dict|callable): If callable, then will be called
-            with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event. Otherwise, should be
-            a dict from event_id to event of all events in state_sets.
+        state_map(dict): a dict from event_id to event, for all events in
+            state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent] is a map from
-        (type, state_key) to event.
+        dict[(str, str), synapse.events.FrozenEvent]:
+            a map from (type, state_key) to event.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory):
         state_sets,
     )
 
-    if callable(state_map_factory):
-        return _resolve_with_state_fac(
-            unconflicted_state, conflicted_state, state_map_factory
-        )
-
-    state_map = state_map_factory
-
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -491,8 +482,26 @@ def _seperate(state_sets):
 
 
 @defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
-                            state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map_factory(func): will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event.
+
+    Returns
+        Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
+            a map from (type, state_key) to event.
+    """
+    if len(state_sets) == 1:
+        defer.returnValue(state_sets[0])
+
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
     needed_events = set(
         event_id
         for event_ids in conflicted_state.itervalues()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..68125006eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
+        """Starts a transaction on the database and runs a given function
 
-        start_time = time.time() * 1000
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
+
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        current_context = LoggingContext.current_context()
 
         after_callbacks = []
         final_callbacks = []
 
         def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runInteraction") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                current_context.copy_to(context)
-                return self._new_transaction(
-                    conn, desc, after_callbacks, final_callbacks, current_context,
-                    func, *args, **kwargs
-                )
+            return self._new_transaction(
+                conn, desc, after_callbacks, final_callbacks, current_context,
+                func, *args, **kwargs
+            )
 
         try:
-            with PreserveLoggingContext():
-                result = yield self._db_pool.runWithConnection(
-                    inner_func, *args, **kwargs
-                )
+            result = yield self.runWithConnection(inner_func, *args, **kwargs)
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
         current_context = LoggingContext.current_context()
 
         start_time = time.time() * 1000
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runWithConnection") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+                sched_duration_ms = time.time() * 1000 - start_time
+                sql_scheduling_timer.inc_by(sched_duration_ms)
+                current_context.add_database_scheduled(sched_duration_ms)
 
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba0da83642..7a9cd3ec90 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
+from synapse.state import resolve_events_with_factory
 from synapse.util.caches.descriptors import cached
 from synapse.types import get_domain_from_id
 
@@ -146,6 +146,9 @@ class _EventPeristenceQueue(object):
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
+                    # handle_queue_loop runs in the sentinel logcontext, so
+                    # there is no need to preserve_fn when running the
+                    # callbacks on the deferred.
                     try:
                         ret = yield per_item_callback(item)
                         item.deferred.callback(ret)
@@ -157,7 +160,11 @@ class _EventPeristenceQueue(object):
                     self._event_persist_queues[room_id] = queue
                 self._currently_persisting_rooms.discard(room_id)
 
-        preserve_fn(handle_queue_loop)()
+        # set handle_queue_loop off on the background. We don't want to
+        # attribute work done in it to the current request, so we drop the
+        # logcontext altogether.
+        with PreserveLoggingContext():
+            handle_queue_loop()
 
     def _get_drainining_queue(self, room_id):
         queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -556,7 +563,7 @@ class EventsStore(SQLBaseStore):
                         to_return.update(evs)
                     defer.returnValue(to_return)
 
-                current_state = yield resolve_events(
+                current_state = yield resolve_events_with_factory(
                     state_sets,
                     state_map_factory=get_events,
                 )
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 6ebc372498..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -173,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+        """Updates the last access time of the given media
+
+        Args:
+            local_media (iterable[str]): Set of media_ids
+            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms: Current time in milliseconds
+        """
         def update_cache_txn(txn):
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
@@ -181,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
             )
 
             txn.executemany(sql, (
-                (time_ts, media_origin, media_id)
-                for media_origin, media_id in origin_id_tuples
+                (time_ms, media_origin, media_id)
+                for media_origin, media_id in remote_media
+            ))
+
+            sql = (
+                "UPDATE local_media_repository SET last_access_ts = ?"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, (
+                (time_ms, media_id)
+                for media_id in local_media
             ))
 
         return self.runInteraction("update_cached_last_access_time", update_cache_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index c9bff408ef..f150ef0103 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,8 +641,13 @@ class UserDirectoryStore(SQLBaseStore):
         """
 
         if self.hs.config.user_directory_search_all_users:
-            join_clause = ""
-            where_clause = "?<>''"  # naughty hack to keep the same number of binds
+            # dummy to keep the number of binds & aliases the same
+            join_clause = """
+                LEFT JOIN (
+                    SELECT NULL as user_id WHERE NULL = ?
+                ) AS s USING (user_id)"
+            """
+            where_clause = ""
         else:
             join_clause = """
                 LEFT JOIN users_in_public_rooms AS p USING (user_id)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
+
     Args:
         name (str): Name for the context for debugging.
     """
 
     __slots__ = [
-        "previous_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag", "alive",
+        "previous_context", "name", "ru_stime", "ru_utime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "usage_start", "usage_end",
+        "main_thread", "alive",
+        "request", "tag",
     ]
 
     thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def add_database_scheduled(self, sched_ms):
+            pass
+
         def __nonzero__(self):
             return False
 
@@ -94,9 +101,17 @@ class LoggingContext(object):
         self.ru_stime = 0.
         self.ru_utime = 0.
         self.db_txn_count = 0
-        self.db_txn_duration = 0.
+
+        # ms spent waiting for db txns, excluding scheduling time
+        self.db_txn_duration_ms = 0
+
+        # ms spent waiting for db txns to be scheduled
+        self.db_sched_duration_ms = 0
+
         self.usage_start = None
+        self.usage_end = None
         self.main_thread = threading.current_thread()
+        self.request = None
         self.tag = ""
         self.alive = True
 
@@ -105,7 +120,11 @@ class LoggingContext(object):
 
     @classmethod
     def current_context(cls):
-        """Get the current logging context from thread local storage"""
+        """Get the current logging context from thread local storage
+
+        Returns:
+            LoggingContext: the current logging context
+        """
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
     @classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
         self.alive = False
 
     def copy_to(self, record):
-        """Copy fields from this context to the record"""
-        for key, value in self.__dict__.items():
-            setattr(record, key, value)
+        """Copy logging fields from this context to a log record or
+        another LoggingContext
+        """
 
-        record.ru_utime, record.ru_stime = self.get_resource_usage()
+        # 'request' is the only field we currently use in the logger, so that's
+        # all we need to copy
+        record.request = self.request
 
     def start(self):
         if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
 
     def add_database_transaction(self, duration_ms):
         self.db_txn_count += 1
-        self.db_txn_duration += duration_ms / 1000.
+        self.db_txn_duration_ms += duration_ms
+
+    def add_database_scheduled(self, sched_ms):
+        """Record a use of the database pool
+
+        Args:
+            sched_ms (int): number of milliseconds it took us to get a
+                connection
+        """
+        self.db_sched_duration_ms += sched_ms
 
 
 class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..059bb7fedf 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-block_timer = metrics.register_distribution(
-    "block_timer",
-    labels=["block_name"]
+# total number of times we have hit this block
+response_count = metrics.register_counter(
+    "block_count",
+    labels=["block_name"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_block_timer:count",
+            "_block_ru_utime:count",
+            "_block_ru_stime:count",
+            "_block_db_txn_count:count",
+            "_block_db_txn_duration:count",
+        )
+    )
+)
+
+block_timer = metrics.register_counter(
+    "block_time_seconds",
+    labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_timer:total",
+    ),
 )
 
-block_ru_utime = metrics.register_distribution(
-    "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+    "block_ru_utime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_utime:total",
+    ),
 )
 
-block_ru_stime = metrics.register_distribution(
-    "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+    "block_ru_stime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_stime:total",
+    ),
 )
 
-block_db_txn_count = metrics.register_distribution(
-    "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+    "block_db_txn_count", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
 )
 
-block_db_txn_duration = metrics.register_distribution(
-    "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+    "block_db_txn_duration_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+    "block_db_sched_duration_seconds", labels=["block_name"],
 )
 
 
@@ -64,7 +101,9 @@ def measure_func(name):
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+        "ru_stime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "created_context",
     ]
 
     def __init__(self, clock, name):
@@ -84,7 +123,8 @@ class Measure(object):
 
         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
         self.db_txn_count = self.start_context.db_txn_count
-        self.db_txn_duration = self.start_context.db_txn_duration
+        self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+        self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
@@ -114,7 +154,12 @@ class Measure(object):
             context.db_txn_count - self.db_txn_count, self.name
         )
         block_db_txn_duration.inc_by(
-            context.db_txn_duration - self.db_txn_duration, self.name
+            (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+            self.name
+        )
+        block_db_sched_duration.inc_by(
+            (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+            self.name
         )
 
         if self.created_context:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 570312da84..c899fecf5d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
 
     def check_context(self, _, expected):
         self.assertEquals(
-            getattr(LoggingContext.current_context(), "test_key", None),
+            getattr(LoggingContext.current_context(), "request", None),
             expected
         )
 
@@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
         lookup_2_deferred = defer.Deferred()
 
         with LoggingContext("one") as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             wait_1_deferred = kr.wait_for_previous_lookups(
                 ["server1"],
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
             wait_1_deferred.addBoth(self.check_context, "one")
 
         with LoggingContext("two") as context_two:
-            context_two.test_key = "two"
+            context_two.request = "two"
 
             # set off another wait. It should block because the first lookup
             # hasn't yet completed.
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
         @defer.inlineCallbacks
         def get_perspectives(**kwargs):
             self.assertEquals(
-                LoggingContext.current_context().test_key, "11",
+                LoggingContext.current_context().request, "11",
             )
             with logcontext.PreserveLoggingContext():
                 yield persp_deferred
@@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
         self.http_client.post_json.side_effect = get_perspectives
 
         with LoggingContext("11") as context_11:
-            context_11.test_key = "11"
+            context_11.request = "11"
 
             # start off a first set of lookups
             res_deferreds = kr.verify_json_objects_for_server(
@@ -173,7 +173,7 @@ class KeyringTestCase(unittest.TestCase):
             self.assertIs(LoggingContext.current_context(), context_11)
 
             context_12 = LoggingContext("12")
-            context_12.test_key = "12"
+            context_12.request = "12"
             with logcontext.PreserveLoggingContext(context_12):
                 # a second request for a server with outstanding requests
                 # should block rather than start a second call
@@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
         sentinel_context = LoggingContext.current_context()
 
         with LoggingContext("one") as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             defer = kr.verify_json_for_server("server9", {})
             try:
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index e2f7765f49..4850722bc5 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
 
     def _check_test_key(self, value):
         self.assertEquals(
-            LoggingContext.current_context().test_key, value
+            LoggingContext.current_context().request, value
         )
 
     def test_with_context(self):
         with LoggingContext() as context_one:
-            context_one.test_key = "test"
+            context_one.request = "test"
             self._check_test_key("test")
 
     @defer.inlineCallbacks
@@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
         @defer.inlineCallbacks
         def competing_callback():
             with LoggingContext() as competing_context:
-                competing_context.test_key = "competing"
+                competing_context.request = "competing"
                 yield sleep(0)
                 self._check_test_key("competing")
 
         reactor.callLater(0, competing_callback)
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
             yield sleep(0)
             self._check_test_key("one")
 
@@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
 
         @defer.inlineCallbacks
         def cb():
-            context_one.test_key = "one"
+            context_one.request = "one"
             yield function()
             self._check_test_key("one")
 
             callback_completed[0] = True
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             # fire off function, but don't wait on it.
             logcontext.preserve_fn(cb)()
@@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
         sentinel_context = LoggingContext.current_context()
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             d1 = logcontext.make_deferred_yieldable(blocking_function())
             # make sure that the context was reset by make_deferred_yieldable
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
         argument isn't actually a deferred"""
 
         with LoggingContext() as context_one:
-            context_one.test_key = "one"
+            context_one.request = "one"
 
             d1 = logcontext.make_deferred_yieldable("bum")
             self._check_test_key("one")