summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mjark@negativecurvature.net>2016-04-21 16:21:49 +0100
committerMark Haines <mjark@negativecurvature.net>2016-04-21 16:21:49 +0100
commit712030aeef97c414d641a65b398355ed74dc7baf (patch)
treea355a2b8c9e66264a991dff3d41d3e19cdd90d5e
parentAdd an HTTP API for removing rejected pushers. (diff)
parentpip install new python dependencies in jenkins.sh (diff)
downloadsynapse-712030aeef97c414d641a65b398355ed74dc7baf.tar.xz
Merge branch 'develop' into markjh/split_pusher
-rwxr-xr-xjenkins-postgres.sh2
-rwxr-xr-xjenkins-sqlite.sh2
-rwxr-xr-xjenkins.sh86
-rw-r--r--synapse/federation/transport/client.py3
-rw-r--r--synapse/handlers/auth.py23
-rw-r--r--synapse/http/client.py5
-rw-r--r--synapse/rest/media/v1/_base.py110
-rw-r--r--synapse/rest/media/v1/base_resource.py460
-rw-r--r--synapse/rest/media/v1/download_resource.py24
-rw-r--r--synapse/rest/media/v1/media_repository.py395
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py55
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py51
-rw-r--r--synapse/rest/media/v1/upload_resource.py51
-rw-r--r--synapse/state.py18
-rw-r--r--synapse/storage/state.py19
-rw-r--r--synapse/util/__init__.py3
-rw-r--r--synapse/util/metrics.py23
-rw-r--r--tests/replication/slave/storage/_base.py4
-rw-r--r--tests/replication/slave/storage/test_events.py15
-rw-r--r--tests/replication/slave/storage/test_receipts.py39
-rw-r--r--tests/test_state.py4
21 files changed, 733 insertions, 659 deletions
diff --git a/jenkins-postgres.sh b/jenkins-postgres.sh
index 9ac86d2593..ae6b111591 100755
--- a/jenkins-postgres.sh
+++ b/jenkins-postgres.sh
@@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
 tox --notest -e py27
 
 TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
 $TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
 
 : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
 
diff --git a/jenkins-sqlite.sh b/jenkins-sqlite.sh
index 345d01936c..9398d9db15 100755
--- a/jenkins-sqlite.sh
+++ b/jenkins-sqlite.sh
@@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
 
 tox --notest -e py27
 TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
 
 : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
 
diff --git a/jenkins.sh b/jenkins.sh
deleted file mode 100755
index b826d510c9..0000000000
--- a/jenkins.sh
+++ /dev/null
@@ -1,86 +0,0 @@
-#!/bin/bash
-
-set -eux
-
-: ${WORKSPACE:="$(pwd)"}
-
-export PYTHONDONTWRITEBYTECODE=yep
-export SYNAPSE_CACHE_FACTOR=1
-
-# Output test results as junit xml
-export TRIAL_FLAGS="--reporter=subunit"
-export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
-# Write coverage reports to a separate file for each process
-export COVERAGE_OPTS="-p"
-export DUMP_COVERAGE_COMMAND="coverage help"
-
-# Output flake8 violations to violations.flake8.log
-# Don't exit with non-0 status code on Jenkins,
-# so that the build steps continue and a later step can decided whether to
-# UNSTABLE or FAILURE this build.
-export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
-
-rm .coverage* || echo "No coverage files to remove"
-
-tox
-
-: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
-
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-
-if [[ ! -e .sytest-base ]]; then
-  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
-else
-  (cd .sytest-base; git fetch -p)
-fi
-
-rm -rf sytest
-git clone .sytest-base sytest --shared
-cd sytest
-
-git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
-
-: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
-: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
-: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
-export PERL5LIB PERL_MB_OPT PERL_MM_OPT
-
-./install-deps.pl
-
-: ${PORT_BASE:=8000}
-
-echo >&2 "Running sytest with SQLite3";
-./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
-    --python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
-
-RUN_POSTGRES=""
-
-for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
-    if psql synapse_jenkins_$port <<< ""; then
-        RUN_POSTGRES="$RUN_POSTGRES:$port"
-        cat > localhost-$port/database.yaml << EOF
-name: psycopg2
-args:
-    database: synapse_jenkins_$port
-EOF
-    fi
-done
-
-# Run if both postgresql databases exist
-if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
-    echo >&2 "Running sytest with PostgreSQL";
-    $TOX_BIN/pip install psycopg2
-    ./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
-        --python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
-else
-    echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
-fi
-
-cd ..
-cp sytest/.coverage.* .
-
-# Combine the coverage reports
-echo "Combining:" .coverage.*
-$TOX_BIN/python -m coverage combine
-# Output coverage to coverage.xml
-$TOX_BIN/coverage xml -o coverage.xml
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 2237e3413c..cd2841c4db 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -179,7 +179,8 @@ class TransportLayerClient(object):
         content = yield self.client.get_json(
             destination=destination,
             path=path,
-            retry_on_dns_fail=True,
+            retry_on_dns_fail=False,
+            timeout=20000,
         )
 
         defer.returnValue(content)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7a13a8b11c..61fe56032a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -428,24 +428,31 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
-        defer.returnValue(
-            not (
-                (yield self._check_ldap_password(user_id, password))
-                or
-                (yield self._check_local_password(user_id, password))
-            ))
+        """
+        Returns:
+            True if the user_id successfully authenticated
+        """
+        valid_ldap = yield self._check_ldap_password(user_id, password)
+        if valid_ldap:
+            defer.returnValue(True)
+
+        valid_local_password = yield self._check_local_password(user_id, password)
+        if valid_local_password:
+            defer.returnValue(True)
+
+        defer.returnValue(False)
 
     @defer.inlineCallbacks
     def _check_local_password(self, user_id, password):
         try:
             user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-            defer.returnValue(not self.validate_hash(password, password_hash))
+            defer.returnValue(self.validate_hash(password, password_hash))
         except LoginError:
             defer.returnValue(False)
 
     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
-        if self.ldap_enabled is not True:
+        if not self.ldap_enabled:
             logger.debug("LDAP not configured")
             defer.returnValue(False)
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 6c89b20984..902ae7a203 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -462,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: None)
 
-    def getContext(self, hostname, port):
+    def getContext(self, hostname=None, port=None):
         return self._context
+
+    def creatorForNetloc(self, hostname, port):
+        return self
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
new file mode 100644
index 0000000000..b9600f2167
--- /dev/null
+++ b/synapse/rest/media/v1/_base.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json, finish_request
+from synapse.api.errors import (
+    cs_error, Codes, SynapseError
+)
+
+from twisted.internet import defer
+from twisted.protocols.basic import FileSender
+
+from synapse.util.stringutils import is_ascii
+
+import os
+
+import logging
+import urllib
+import urlparse
+
+logger = logging.getLogger(__name__)
+
+
+def parse_media_id(request):
+    try:
+        # This allows users to append e.g. /test.png to the URL. Useful for
+        # clients that parse the URL to see content type.
+        server_name, media_id = request.postpath[:2]
+        file_name = None
+        if len(request.postpath) > 2:
+            try:
+                file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
+            except UnicodeDecodeError:
+                pass
+        return server_name, media_id, file_name
+    except:
+        raise SynapseError(
+            404,
+            "Invalid media id token %r" % (request.postpath,),
+            Codes.UNKNOWN,
+        )
+
+
+def respond_404(request):
+    respond_with_json(
+        request, 404,
+        cs_error(
+            "Not found %r" % (request.postpath,),
+            code=Codes.NOT_FOUND,
+        ),
+        send_cors=True
+    )
+
+
+@defer.inlineCallbacks
+def respond_with_file(request, media_type, file_path,
+                      file_size=None, upload_name=None):
+    logger.debug("Responding with %r", file_path)
+
+    if os.path.isfile(file_path):
+        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+        if upload_name:
+            if is_ascii(upload_name):
+                request.setHeader(
+                    b"Content-Disposition",
+                    b"inline; filename=%s" % (
+                        urllib.quote(upload_name.encode("utf-8")),
+                    ),
+                )
+            else:
+                request.setHeader(
+                    b"Content-Disposition",
+                    b"inline; filename*=utf-8''%s" % (
+                        urllib.quote(upload_name.encode("utf-8")),
+                    ),
+                )
+
+        # cache for at least a day.
+        # XXX: we might want to turn this off for data we don't want to
+        # recommend caching as it's sensitive or private - or at least
+        # select private. don't bother setting Expires as all our
+        # clients are smart enough to be happy with Cache-Control
+        request.setHeader(
+            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+        )
+        if file_size is None:
+            stat = os.stat(file_path)
+            file_size = stat.st_size
+
+        request.setHeader(
+            b"Content-Length", b"%d" % (file_size,)
+        )
+
+        with open(file_path, "rb") as f:
+            yield FileSender().beginFileTransfer(f, request)
+
+        finish_request(request)
+    else:
+        respond_404(request)
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
deleted file mode 100644
index 2b1938dc8e..0000000000
--- a/synapse/rest/media/v1/base_resource.py
+++ /dev/null
@@ -1,460 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .thumbnailer import Thumbnailer
-
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-from synapse.http.server import respond_with_json, finish_request
-from synapse.util.stringutils import random_string
-from synapse.api.errors import (
-    cs_error, Codes, SynapseError
-)
-
-from twisted.internet import defer, threads
-from twisted.web.resource import Resource
-from twisted.protocols.basic import FileSender
-
-from synapse.util.async import ObservableDeferred
-from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
-
-import os
-
-import cgi
-import logging
-import urllib
-import urlparse
-
-logger = logging.getLogger(__name__)
-
-
-def parse_media_id(request):
-    try:
-        # This allows users to append e.g. /test.png to the URL. Useful for
-        # clients that parse the URL to see content type.
-        server_name, media_id = request.postpath[:2]
-        file_name = None
-        if len(request.postpath) > 2:
-            try:
-                file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
-            except UnicodeDecodeError:
-                pass
-        return server_name, media_id, file_name
-    except:
-        raise SynapseError(
-            404,
-            "Invalid media id token %r" % (request.postpath,),
-            Codes.UNKNOWN,
-        )
-
-
-class BaseMediaResource(Resource):
-    isLeaf = True
-
-    def __init__(self, hs, filepaths):
-        Resource.__init__(self)
-        self.auth = hs.get_auth()
-        self.client = MatrixFederationHttpClient(hs)
-        self.clock = hs.get_clock()
-        self.server_name = hs.hostname
-        self.store = hs.get_datastore()
-        self.max_upload_size = hs.config.max_upload_size
-        self.max_image_pixels = hs.config.max_image_pixels
-        self.max_spider_size = hs.config.max_spider_size
-        self.filepaths = filepaths
-        self.version_string = hs.version_string
-        self.downloads = {}
-        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
-        self.thumbnail_requirements = hs.config.thumbnail_requirements
-
-    def _respond_404(self, request):
-        respond_with_json(
-            request, 404,
-            cs_error(
-                "Not found %r" % (request.postpath,),
-                code=Codes.NOT_FOUND,
-            ),
-            send_cors=True
-        )
-
-    @staticmethod
-    def _makedirs(filepath):
-        dirname = os.path.dirname(filepath)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-
-    def _get_remote_media(self, server_name, media_id):
-        key = (server_name, media_id)
-        download = self.downloads.get(key)
-        if download is None:
-            download = self._get_remote_media_impl(server_name, media_id)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.downloads[key] = download
-
-            @download.addBoth
-            def callback(media_info):
-                del self.downloads[key]
-                return media_info
-        return download.observe()
-
-    @defer.inlineCallbacks
-    def _get_remote_media_impl(self, server_name, media_id):
-        media_info = yield self.store.get_cached_remote_media(
-            server_name, media_id
-        )
-        if not media_info:
-            media_info = yield self._download_remote_file(
-                server_name, media_id
-            )
-        defer.returnValue(media_info)
-
-    @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id):
-        file_id = random_string(24)
-
-        fname = self.filepaths.remote_media_filepath(
-            server_name, file_id
-        )
-        self._makedirs(fname)
-
-        try:
-            with open(fname, "wb") as f:
-                request_path = "/".join((
-                    "/_matrix/media/v1/download", server_name, media_id,
-                ))
-                length, headers = yield self.client.get_file(
-                    server_name, request_path, output_stream=f,
-                    max_size=self.max_upload_size,
-                )
-            media_type = headers["Content-Type"][0]
-            time_now_ms = self.clock.time_msec()
-
-            content_disposition = headers.get("Content-Disposition", None)
-            if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
-                upload_name = None
-
-                # First check if there is a valid UTF-8 filename
-                upload_name_utf8 = params.get("filename*", None)
-                if upload_name_utf8:
-                    if upload_name_utf8.lower().startswith("utf-8''"):
-                        upload_name = upload_name_utf8[7:]
-
-                # If there isn't check for an ascii name.
-                if not upload_name:
-                    upload_name_ascii = params.get("filename", None)
-                    if upload_name_ascii and is_ascii(upload_name_ascii):
-                        upload_name = upload_name_ascii
-
-                if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
-                    try:
-                        upload_name = upload_name.decode("utf-8")
-                    except UnicodeDecodeError:
-                        upload_name = None
-            else:
-                upload_name = None
-
-            yield self.store.store_cached_remote_media(
-                origin=server_name,
-                media_id=media_id,
-                media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
-                upload_name=upload_name,
-                media_length=length,
-                filesystem_id=file_id,
-            )
-        except:
-            os.remove(fname)
-            raise
-
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        yield self._generate_remote_thumbnails(
-            server_name, media_id, media_info
-        )
-
-        defer.returnValue(media_info)
-
-    @defer.inlineCallbacks
-    def _respond_with_file(self, request, media_type, file_path,
-                           file_size=None, upload_name=None):
-        logger.debug("Responding with %r", file_path)
-
-        if os.path.isfile(file_path):
-            request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-            if upload_name:
-                if is_ascii(upload_name):
-                    request.setHeader(
-                        b"Content-Disposition",
-                        b"inline; filename=%s" % (
-                            urllib.quote(upload_name.encode("utf-8")),
-                        ),
-                    )
-                else:
-                    request.setHeader(
-                        b"Content-Disposition",
-                        b"inline; filename*=utf-8''%s" % (
-                            urllib.quote(upload_name.encode("utf-8")),
-                        ),
-                    )
-
-            # cache for at least a day.
-            # XXX: we might want to turn this off for data we don't want to
-            # recommend caching as it's sensitive or private - or at least
-            # select private. don't bother setting Expires as all our
-            # clients are smart enough to be happy with Cache-Control
-            request.setHeader(
-                b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-            )
-            if file_size is None:
-                stat = os.stat(file_path)
-                file_size = stat.st_size
-
-            request.setHeader(
-                b"Content-Length", b"%d" % (file_size,)
-            )
-
-            with open(file_path, "rb") as f:
-                yield FileSender().beginFileTransfer(f, request)
-
-            finish_request(request)
-        else:
-            self._respond_404(request)
-
-    def _get_thumbnail_requirements(self, media_type):
-        return self.thumbnail_requirements.get(media_type, ())
-
-    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
-                            t_method, t_type):
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        if m_width * m_height >= self.max_image_pixels:
-            logger.info(
-                "Image too large to thumbnail %r x %r > %r",
-                m_width, m_height, self.max_image_pixels
-            )
-            return
-
-        if t_method == "crop":
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-        elif t_method == "scale":
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-        else:
-            t_len = None
-
-        return t_len
-
-    @defer.inlineCallbacks
-    def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                        t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
-
-        t_path = self.filepaths.local_media_thumbnail(
-            media_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
-
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
-            self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
-
-        if t_len:
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
-
-            defer.returnValue(t_path)
-
-    @defer.inlineCallbacks
-    def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
-                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-
-        t_path = self.filepaths.remote_media_thumbnail(
-            server_name, file_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
-
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
-            self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
-
-        if t_len:
-            yield self.store.store_remote_media_thumbnail(
-                server_name, media_id, file_id,
-                t_width, t_height, t_type, t_method, t_len
-            )
-
-            defer.returnValue(t_path)
-
-    @defer.inlineCallbacks
-    def _generate_local_thumbnails(self, media_id, media_info):
-        media_type = media_info["media_type"]
-        requirements = self._get_thumbnail_requirements(media_type)
-        if not requirements:
-            return
-
-        input_path = self.filepaths.local_media_filepath(media_id)
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        if m_width * m_height >= self.max_image_pixels:
-            logger.info(
-                "Image too large to thumbnail %r x %r > %r",
-                m_width, m_height, self.max_image_pixels
-            )
-            return
-
-        local_thumbnails = []
-
-        def generate_thumbnails():
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                ))
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                ))
-
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for l in local_thumbnails:
-            yield self.store.store_local_thumbnail(*l)
-
-        defer.returnValue({
-            "width": m_width,
-            "height": m_height,
-        })
-
-    @defer.inlineCallbacks
-    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
-        media_type = media_info["media_type"]
-        file_id = media_info["filesystem_id"]
-        requirements = self._get_thumbnail_requirements(media_type)
-        if not requirements:
-            return
-
-        remote_thumbnails = []
-
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        def generate_thumbnails():
-            if m_width * m_height >= self.max_image_pixels:
-                logger.info(
-                    "Image too large to thumbnail %r x %r > %r",
-                    m_width, m_height, self.max_image_pixels
-                )
-                return
-
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
-                    server_name, media_id, file_id,
-                    t_width, t_height, t_type, t_method, t_len
-                ])
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
-                    server_name, media_id, file_id,
-                    t_width, t_height, t_type, t_method, t_len
-                ])
-
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for r in remote_thumbnails:
-            yield self.store.store_remote_media_thumbnail(*r)
-
-        defer.returnValue({
-            "width": m_width,
-            "height": m_height,
-        })
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 1aad6b3551..510884262c 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_resource import BaseMediaResource, parse_media_id
+from ._base import parse_media_id, respond_with_file, respond_404
+from twisted.web.resource import Resource
 from synapse.http.server import request_handler
 
 from twisted.web.server import NOT_DONE_YET
@@ -24,7 +25,18 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DownloadResource(BaseMediaResource):
+class DownloadResource(Resource):
+    isLeaf = True
+
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.filepaths = media_repo.filepaths
+        self.media_repo = media_repo
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.version_string = hs.version_string
+
     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET
@@ -44,7 +56,7 @@ class DownloadResource(BaseMediaResource):
     def _respond_local_file(self, request, media_id, name):
         media_info = yield self.store.get_local_media(media_id)
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         media_type = media_info["media_type"]
@@ -52,14 +64,14 @@ class DownloadResource(BaseMediaResource):
         upload_name = name if name else media_info["upload_name"]
         file_path = self.filepaths.local_media_filepath(media_id)
 
-        yield self._respond_with_file(
+        yield respond_with_file(
             request, media_type, file_path, media_length,
             upload_name=upload_name,
         )
 
     @defer.inlineCallbacks
     def _respond_remote_file(self, request, server_name, media_id, name):
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
 
         media_type = media_info["media_type"]
         media_length = media_info["media_length"]
@@ -70,7 +82,7 @@ class DownloadResource(BaseMediaResource):
             server_name, filesystem_id
         )
 
-        yield self._respond_with_file(
+        yield respond_with_file(
             request, media_type, file_path, media_length,
             upload_name=upload_name,
         )
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 77fb0313c5..d96bf9afe2 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -22,11 +22,395 @@ from .filepath import MediaFilePaths
 
 from twisted.web.resource import Resource
 
+from .thumbnailer import Thumbnailer
+
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.util.stringutils import random_string
+
+from twisted.internet import defer, threads
+
+from synapse.util.async import ObservableDeferred
+from synapse.util.stringutils import is_ascii
+from synapse.util.logcontext import preserve_context_over_fn
+
+import os
+
+import cgi
 import logging
+import urlparse
 
 logger = logging.getLogger(__name__)
 
 
+class MediaRepository(object):
+    def __init__(self, hs, filepaths):
+        self.auth = hs.get_auth()
+        self.client = MatrixFederationHttpClient(hs)
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.max_upload_size = hs.config.max_upload_size
+        self.max_image_pixels = hs.config.max_image_pixels
+        self.filepaths = filepaths
+        self.downloads = {}
+        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+        self.thumbnail_requirements = hs.config.thumbnail_requirements
+
+    @staticmethod
+    def _makedirs(filepath):
+        dirname = os.path.dirname(filepath)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+    @defer.inlineCallbacks
+    def create_content(self, media_type, upload_name, content, content_length,
+                       auth_user):
+        media_id = random_string(24)
+
+        fname = self.filepaths.local_media_filepath(media_id)
+        self._makedirs(fname)
+
+        # This shouldn't block for very long because the content will have
+        # already been uploaded at this point.
+        with open(fname, "wb") as f:
+            f.write(content)
+
+        yield self.store.store_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+        media_info = {
+            "media_type": media_type,
+            "media_length": content_length,
+        }
+
+        yield self._generate_local_thumbnails(media_id, media_info)
+
+        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+
+    def get_remote_media(self, server_name, media_id):
+        key = (server_name, media_id)
+        download = self.downloads.get(key)
+        if download is None:
+            download = self._get_remote_media_impl(server_name, media_id)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self.downloads[key] = download
+
+            @download.addBoth
+            def callback(media_info):
+                del self.downloads[key]
+                return media_info
+        return download.observe()
+
+    @defer.inlineCallbacks
+    def _get_remote_media_impl(self, server_name, media_id):
+        media_info = yield self.store.get_cached_remote_media(
+            server_name, media_id
+        )
+        if not media_info:
+            media_info = yield self._download_remote_file(
+                server_name, media_id
+            )
+        defer.returnValue(media_info)
+
+    @defer.inlineCallbacks
+    def _download_remote_file(self, server_name, media_id):
+        file_id = random_string(24)
+
+        fname = self.filepaths.remote_media_filepath(
+            server_name, file_id
+        )
+        self._makedirs(fname)
+
+        try:
+            with open(fname, "wb") as f:
+                request_path = "/".join((
+                    "/_matrix/media/v1/download", server_name, media_id,
+                ))
+                length, headers = yield self.client.get_file(
+                    server_name, request_path, output_stream=f,
+                    max_size=self.max_upload_size,
+                )
+            media_type = headers["Content-Type"][0]
+            time_now_ms = self.clock.time_msec()
+
+            content_disposition = headers.get("Content-Disposition", None)
+            if content_disposition:
+                _, params = cgi.parse_header(content_disposition[0],)
+                upload_name = None
+
+                # First check if there is a valid UTF-8 filename
+                upload_name_utf8 = params.get("filename*", None)
+                if upload_name_utf8:
+                    if upload_name_utf8.lower().startswith("utf-8''"):
+                        upload_name = upload_name_utf8[7:]
+
+                # If there isn't check for an ascii name.
+                if not upload_name:
+                    upload_name_ascii = params.get("filename", None)
+                    if upload_name_ascii and is_ascii(upload_name_ascii):
+                        upload_name = upload_name_ascii
+
+                if upload_name:
+                    upload_name = urlparse.unquote(upload_name)
+                    try:
+                        upload_name = upload_name.decode("utf-8")
+                    except UnicodeDecodeError:
+                        upload_name = None
+            else:
+                upload_name = None
+
+            yield self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
+        except:
+            os.remove(fname)
+            raise
+
+        media_info = {
+            "media_type": media_type,
+            "media_length": length,
+            "upload_name": upload_name,
+            "created_ts": time_now_ms,
+            "filesystem_id": file_id,
+        }
+
+        yield self._generate_remote_thumbnails(
+            server_name, media_id, media_info
+        )
+
+        defer.returnValue(media_info)
+
+    def _get_thumbnail_requirements(self, media_type):
+        return self.thumbnail_requirements.get(media_type, ())
+
+    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+                            t_method, t_type):
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        if m_width * m_height >= self.max_image_pixels:
+            logger.info(
+                "Image too large to thumbnail %r x %r > %r",
+                m_width, m_height, self.max_image_pixels
+            )
+            return
+
+        if t_method == "crop":
+            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+        elif t_method == "scale":
+            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+        else:
+            t_len = None
+
+        return t_len
+
+    @defer.inlineCallbacks
+    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
+                                       t_method, t_type):
+        input_path = self.filepaths.local_media_filepath(media_id)
+
+        t_path = self.filepaths.local_media_thumbnail(
+            media_id, t_width, t_height, t_type, t_method
+        )
+        self._makedirs(t_path)
+
+        t_len = yield preserve_context_over_fn(
+            threads.deferToThread,
+            self._generate_thumbnail,
+            input_path, t_path, t_width, t_height, t_method, t_type
+        )
+
+        if t_len:
+            yield self.store.store_local_thumbnail(
+                media_id, t_width, t_height, t_type, t_method, t_len
+            )
+
+            defer.returnValue(t_path)
+
+    @defer.inlineCallbacks
+    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
+                                        t_width, t_height, t_method, t_type):
+        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+
+        t_path = self.filepaths.remote_media_thumbnail(
+            server_name, file_id, t_width, t_height, t_type, t_method
+        )
+        self._makedirs(t_path)
+
+        t_len = yield preserve_context_over_fn(
+            threads.deferToThread,
+            self._generate_thumbnail,
+            input_path, t_path, t_width, t_height, t_method, t_type
+        )
+
+        if t_len:
+            yield self.store.store_remote_media_thumbnail(
+                server_name, media_id, file_id,
+                t_width, t_height, t_type, t_method, t_len
+            )
+
+            defer.returnValue(t_path)
+
+    @defer.inlineCallbacks
+    def _generate_local_thumbnails(self, media_id, media_info):
+        media_type = media_info["media_type"]
+        requirements = self._get_thumbnail_requirements(media_type)
+        if not requirements:
+            return
+
+        input_path = self.filepaths.local_media_filepath(media_id)
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        if m_width * m_height >= self.max_image_pixels:
+            logger.info(
+                "Image too large to thumbnail %r x %r > %r",
+                m_width, m_height, self.max_image_pixels
+            )
+            return
+
+        local_thumbnails = []
+
+        def generate_thumbnails():
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
+
+        for l in local_thumbnails:
+            yield self.store.store_local_thumbnail(*l)
+
+        defer.returnValue({
+            "width": m_width,
+            "height": m_height,
+        })
+
+    @defer.inlineCallbacks
+    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
+        media_type = media_info["media_type"]
+        file_id = media_info["filesystem_id"]
+        requirements = self._get_thumbnail_requirements(media_type)
+        if not requirements:
+            return
+
+        remote_thumbnails = []
+
+        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        def generate_thumbnails():
+            if m_width * m_height >= self.max_image_pixels:
+                logger.info(
+                    "Image too large to thumbnail %r x %r > %r",
+                    m_width, m_height, self.max_image_pixels
+                )
+                return
+
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
+
+        for r in remote_thumbnails:
+            yield self.store.store_remote_media_thumbnail(*r)
+
+        defer.returnValue({
+            "width": m_width,
+            "height": m_height,
+        })
+
+
 class MediaRepositoryResource(Resource):
     """File uploading and downloading.
 
@@ -75,9 +459,12 @@ class MediaRepositoryResource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
         filepaths = MediaFilePaths(hs.config.media_store_path)
-        self.putChild("upload", UploadResource(hs, filepaths))
-        self.putChild("download", DownloadResource(hs, filepaths))
-        self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
+
+        media_repo = MediaRepository(hs, filepaths)
+
+        self.putChild("upload", UploadResource(hs, media_repo))
+        self.putChild("download", DownloadResource(hs, media_repo))
+        self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
         self.putChild("identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            self.putChild("preview_url", PreviewUrlResource(hs, filepaths))
+            self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c27ba72735..69327ac493 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -13,10 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_resource import BaseMediaResource
-
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import (
     SynapseError, Codes,
@@ -41,12 +40,22 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class PreviewUrlResource(BaseMediaResource):
+class PreviewUrlResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, filepaths):
-        BaseMediaResource.__init__(self, hs, filepaths)
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.version_string = hs.version_string
+        self.filepaths = media_repo.filepaths
+        self.max_spider_size = hs.config.max_spider_size
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
         self.client = SpiderHttpClient(hs)
+        self.media_repo = media_repo
+
         if hasattr(hs.config, "url_preview_url_blacklist"):
             self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
@@ -156,7 +165,7 @@ class PreviewUrlResource(BaseMediaResource):
         logger.debug("got media_info of '%s'" % media_info)
 
         if self._is_media(media_info['media_type']):
-            dims = yield self._generate_local_thumbnails(
+            dims = yield self.media_repo._generate_local_thumbnails(
                 media_info['filesystem_id'], media_info
             )
 
@@ -179,23 +188,27 @@ class PreviewUrlResource(BaseMediaResource):
         elif self._is_html(media_info['media_type']):
             # TODO: somehow stop a big HTML tree from exploding synapse's RAM
 
-            from lxml import html
+            from lxml import etree
+
+            file = open(media_info['filename'])
+            body = file.read()
+            file.close()
+
+            # clobber the encoding from the content-type, or default to utf-8
+            # XXX: this overrides any <meta/> or XML charset headers in the body
+            # which may pose problems, but so far seems to work okay.
+            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
+            encoding = match.group(1) if match else "utf-8"
 
             try:
-                tree = html.parse(media_info['filename'])
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body, parser)
                 og = yield self._calc_og(tree, media_info, requester)
             except UnicodeDecodeError:
-                # XXX: evil evil bodge
-                # Empirically, sites like google.com mix Latin-1 and utf-8
-                # encodings in the same page.  The rogue Latin-1 characters
-                # cause lxml to choke with a UnicodeDecodeError, so if we
-                # see this we go and do a manual decode of the HTML before
-                # handing it to lxml as utf-8 encoding, counter-intuitively,
-                # which seems to make it happier...
-                file = open(media_info['filename'])
-                body = file.read()
-                file.close()
-                tree = html.fromstring(body.decode('utf-8', 'ignore'))
+                # blindly try decoding the body as utf-8, which seems to fix
+                # the charset mismatches on https://google.com
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
                 og = yield self._calc_og(tree, media_info, requester)
 
         else:
@@ -287,7 +300,7 @@ class PreviewUrlResource(BaseMediaResource):
 
             if self._is_media(image_info['media_type']):
                 # TODO: make sure we don't choke on white-on-transparent images
-                dims = yield self._generate_local_thumbnails(
+                dims = yield self.media_repo._generate_local_thumbnails(
                     image_info['filesystem_id'], image_info
                 )
                 if dims:
@@ -358,7 +371,7 @@ class PreviewUrlResource(BaseMediaResource):
         file_id = random_string(24)
 
         fname = self.filepaths.local_media_filepath(file_id)
-        self._makedirs(fname)
+        self.media_repo._makedirs(fname)
 
         try:
             with open(fname, "wb") as f:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 40ef22459c..234dd4261c 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 
-from .base_resource import BaseMediaResource, parse_media_id
+from ._base import parse_media_id, respond_404, respond_with_file
+from twisted.web.resource import Resource
 from synapse.http.servlet import parse_string, parse_integer
 from synapse.http.server import request_handler
 
@@ -26,9 +27,19 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class ThumbnailResource(BaseMediaResource):
+class ThumbnailResource(Resource):
     isLeaf = True
 
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.store = hs.get_datastore()
+        self.filepaths = media_repo.filepaths
+        self.media_repo = media_repo
+        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+        self.server_name = hs.hostname
+        self.version_string = hs.version_string
+
     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET
@@ -69,12 +80,12 @@ class ThumbnailResource(BaseMediaResource):
         media_info = yield self.store.get_local_media(media_id)
 
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         # if media_info["media_type"] == "image/svg+xml":
         #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
         #     return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -91,7 +102,7 @@ class ThumbnailResource(BaseMediaResource):
             file_path = self.filepaths.local_media_thumbnail(
                 media_id, t_width, t_height, t_type, t_method,
             )
-            yield self._respond_with_file(request, t_type, file_path)
+            yield respond_with_file(request, t_type, file_path)
 
         else:
             yield self._respond_default_thumbnail(
@@ -105,12 +116,12 @@ class ThumbnailResource(BaseMediaResource):
         media_info = yield self.store.get_local_media(media_id)
 
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         # if media_info["media_type"] == "image/svg+xml":
         #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
         #     return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -124,18 +135,18 @@ class ThumbnailResource(BaseMediaResource):
                 file_path = self.filepaths.local_media_thumbnail(
                     media_id, desired_width, desired_height, desired_type, desired_method,
                 )
-                yield self._respond_with_file(request, desired_type, file_path)
+                yield respond_with_file(request, desired_type, file_path)
                 return
 
         logger.debug("We don't have a local thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self._generate_local_exact_thumbnail(
+        file_path = yield self.media_repo.generate_local_exact_thumbnail(
             media_id, desired_width, desired_height, desired_method, desired_type
         )
 
         if file_path:
-            yield self._respond_with_file(request, desired_type, file_path)
+            yield respond_with_file(request, desired_type, file_path)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, desired_width, desired_height,
@@ -146,11 +157,11 @@ class ThumbnailResource(BaseMediaResource):
     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
                                              desired_width, desired_height,
                                              desired_method, desired_type):
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
 
         # if media_info["media_type"] == "image/svg+xml":
         #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
         #     return
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
@@ -170,19 +181,19 @@ class ThumbnailResource(BaseMediaResource):
                     server_name, file_id, desired_width, desired_height,
                     desired_type, desired_method,
                 )
-                yield self._respond_with_file(request, desired_type, file_path)
+                yield respond_with_file(request, desired_type, file_path)
                 return
 
         logger.debug("We don't have a local thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self._generate_remote_exact_thumbnail(
+        file_path = yield self.media_repo.generate_remote_exact_thumbnail(
             server_name, file_id, media_id, desired_width,
             desired_height, desired_method, desired_type
         )
 
         if file_path:
-            yield self._respond_with_file(request, desired_type, file_path)
+            yield respond_with_file(request, desired_type, file_path)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, desired_width, desired_height,
@@ -194,11 +205,11 @@ class ThumbnailResource(BaseMediaResource):
                                   height, method, m_type):
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead.
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
 
         # if media_info["media_type"] == "image/svg+xml":
         #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield self._respond_with_file(request, media_info["media_type"], file_path)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
         #     return
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
@@ -219,7 +230,7 @@ class ThumbnailResource(BaseMediaResource):
             file_path = self.filepaths.remote_media_thumbnail(
                 server_name, file_id, t_width, t_height, t_type, t_method,
             )
-            yield self._respond_with_file(request, t_type, file_path, t_length)
+            yield respond_with_file(request, t_type, file_path, t_length)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, width, height, method, m_type,
@@ -245,7 +256,7 @@ class ThumbnailResource(BaseMediaResource):
                 "_default", "_default",
             )
         if not thumbnail_infos:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         thumbnail_info = self._select_thumbnail(
@@ -261,7 +272,7 @@ class ThumbnailResource(BaseMediaResource):
         file_path = self.filepaths.default_thumbnail(
             top_level_type, sub_type, t_width, t_height, t_type, t_method,
         )
-        yield self.respond_with_file(request, t_type, file_path, t_length)
+        yield respond_with_file(request, t_type, file_path, t_length)
 
     def _select_thumbnail(self, desired_width, desired_height, desired_method,
                           desired_type, thumbnail_infos):
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 9c7ad4ae85..299e1f6e56 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,33 @@
 
 from synapse.http.server import respond_with_json, request_handler
 
-from synapse.util.stringutils import random_string
 from synapse.api.errors import SynapseError
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
 
-from .base_resource import BaseMediaResource
+from twisted.web.resource import Resource
 
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class UploadResource(BaseMediaResource):
+class UploadResource(Resource):
+    isLeaf = True
+
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.media_repo = media_repo
+        self.filepaths = media_repo.filepaths
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.auth = hs.get_auth()
+        self.max_upload_size = hs.config.max_upload_size
+        self.version_string = hs.version_string
+
     def render_POST(self, request):
         self._async_render_POST(request)
         return NOT_DONE_YET
@@ -37,36 +50,6 @@ class UploadResource(BaseMediaResource):
         respond_with_json(request, 200, {}, send_cors=True)
         return NOT_DONE_YET
 
-    @defer.inlineCallbacks
-    def create_content(self, media_type, upload_name, content, content_length,
-                       auth_user):
-        media_id = random_string(24)
-
-        fname = self.filepaths.local_media_filepath(media_id)
-        self._makedirs(fname)
-
-        # This shouldn't block for very long because the content will have
-        # already been uploaded at this point.
-        with open(fname, "wb") as f:
-            f.write(content)
-
-        yield self.store.store_local_media(
-            media_id=media_id,
-            media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
-            upload_name=upload_name,
-            media_length=content_length,
-            user_id=auth_user,
-        )
-        media_info = {
-            "media_type": media_type,
-            "media_length": content_length,
-        }
-
-        yield self._generate_local_thumbnails(media_id, media_info)
-
-        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
-
     @request_handler
     @defer.inlineCallbacks
     def _async_render_POST(self, request):
@@ -108,7 +91,7 @@ class UploadResource(BaseMediaResource):
         #     disposition = headers.getRawHeaders("Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
-        content_uri = yield self.create_content(
+        content_uri = yield self.media_repo.create_content(
             media_type, upload_name, request.content.read(),
             content_length, requester.user
         )
diff --git a/synapse/state.py b/synapse/state.py
index 58211f5feb..d0f76dc4f5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -214,7 +214,7 @@ class StateHandler(object):
 
         if self._state_cache is not None:
             cache = self._state_cache.get(group_names, None)
-            if cache and cache.state_group:
+            if cache:
                 cache.ts = self.clock.time_msec()
 
                 event_dict = yield self.store.get_events(cache.state.values())
@@ -230,22 +230,34 @@ class StateHandler(object):
                     (cache.state_group, state, prev_states)
                 )
 
+        logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
+
         new_state, prev_states = self._resolve_events(
             state_groups.values(), event_type, state_key
         )
 
+        state_group = None
+        new_state_event_ids = frozenset(e.event_id for e in new_state.values())
+        for sg, events in state_groups.items():
+            if new_state_event_ids == frozenset(e.event_id for e in events):
+                state_group = sg
+                break
+
         if self._state_cache is not None:
             cache = _StateCacheEntry(
                 state={key: event.event_id for key, event in new_state.items()},
-                state_group=None,
+                state_group=state_group,
                 ts=self.clock.time_msec()
             )
 
             self._state_cache[group_names] = cache
 
-        defer.returnValue((None, new_state, prev_states))
+        defer.returnValue((state_group, new_state, prev_states))
 
     def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
         if event.is_state():
             return self._resolve_events(
                 state_sets, event.type, event.state_key
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c5d2a3a6df..5b743db67a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,6 +174,12 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]
         return self.runInteraction("get_current_state_for_key", f)
 
+    @cached(num_args=2, lru=True, max_entries=1000)
+    def _get_state_group_from_group(self, group, types):
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="_get_state_group_from_group",
+                list_name="groups", num_args=2, inlineCallbacks=True)
     def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
@@ -201,18 +207,23 @@ class StateStore(SQLBaseStore):
             txn.execute(sql, args)
             rows = self.cursor_to_dict(txn)
 
-            results = {}
+            results = {group: {} for group in groups}
             for row in rows:
                 key = (row["type"], row["state_key"])
-                results.setdefault(row["state_group"], {})[key] = row["event_id"]
+                results[row["state_group"]][key] = row["event_id"]
             return results
 
+        results = {}
+
         chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
         for chunk in chunks:
-            return self.runInteraction(
+            res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
                 f, chunk
             )
+            results.update(res)
+
+        defer.returnValue(results)
 
     @defer.inlineCallbacks
     def get_state_for_events(self, event_ids, types):
@@ -359,6 +370,8 @@ class StateStore(SQLBaseStore):
         a `state_key` of None matches all state_keys. If `types` is None then
         all events are returned.
         """
+        if types:
+            types = frozenset(types)
         results = {}
         missing_groups = []
         if types is not None:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b462495eb8..2b3f0bef3c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.api.errors import SynapseError
 from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
@@ -80,7 +81,7 @@ class Clock(object):
 
         def timed_out_fn():
             try:
-                ret_deferred.errback(RuntimeError("Timed out"))
+                ret_deferred.errback(SynapseError(504, "Timed out"))
             except:
                 pass
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..e1f374807e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution(
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration"
+        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
     ]
 
     def __init__(self, clock, name):
@@ -58,14 +58,20 @@ class Measure(object):
         self.name = name
         self.start_context = None
         self.start = None
+        self.created_context = False
 
     def __enter__(self):
         self.start = self.clock.time_msec()
         self.start_context = LoggingContext.current_context()
-        if self.start_context:
-            self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
-            self.db_txn_count = self.start_context.db_txn_count
-            self.db_txn_duration = self.start_context.db_txn_duration
+        if not self.start_context:
+            logger.warn("Entered Measure without log context: %s", self.name)
+            self.start_context = LoggingContext("Measure")
+            self.start_context.__enter__()
+            self.created_context = True
+
+        self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+        self.db_txn_count = self.start_context.db_txn_count
+        self.db_txn_duration = self.start_context.db_txn_duration
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None or not self.start_context:
@@ -91,7 +97,12 @@ class Measure(object):
 
         block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
         block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
-        block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+        block_db_txn_count.inc_by(
+            context.db_txn_count - self.db_txn_count, self.name
+        )
         block_db_txn_duration.inc_by(
             context.db_txn_duration - self.db_txn_duration, self.name
         )
+
+        if self.created_context:
+            self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 983caafe8a..1f13cd0bc0 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,8 +15,6 @@
 from twisted.internet import defer
 from tests import unittest
 
-from synapse.replication.slave.storage.events import SlavedEventStore
-
 from mock import Mock, NonCallableMock
 from tests.utils import setup_test_homeserver
 from synapse.replication.resource import ReplicationResource
@@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.replication = ReplicationResource(self.hs)
 
         self.master_store = self.hs.get_datastore()
-        self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs)
+        self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
     @defer.inlineCallbacks
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index baa4a26eb5..41a626cf70 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase
 
 from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
+from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.roommember import RoomsForUser
 
 from twisted.internet import defer
@@ -43,6 +44,8 @@ def patch__eq__(cls):
 
 class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
+    STORE_TYPE = SlavedEventStore
+
     def setUp(self):
         # Patch up the equality operator for events so that we can check
         # whether lists of events match using assertEquals
@@ -251,6 +254,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
         yield self.check("get_event", [msg.event_id], redacted)
 
+    @defer.inlineCallbacks
+    def test_invites(self):
+        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+        event = yield self.persist(
+            type="m.room.member", key=USER_ID_2, membership="invite"
+        )
+        yield self.replicate()
+        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
+            ROOM_ID, USER_ID, "invite", event.event_id,
+            event.internal_metadata.stream_ordering
+        )])
+
     event_id = 0
 
     @defer.inlineCallbacks
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
new file mode 100644
index 0000000000..6624fe4eea
--- /dev/null
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -0,0 +1,39 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+
+from twisted.internet import defer
+
+USER_ID = "@feeling:blue"
+ROOM_ID = "!room:blue"
+EVENT_ID = "$event:blue"
+
+
+class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
+
+    STORE_TYPE = SlavedReceiptsStore
+
+    @defer.inlineCallbacks
+    def test_receipt(self):
+        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+        yield self.master_store.insert_receipt(
+            ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
+        )
+        yield self.replicate()
+        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
+            ROOM_ID: EVENT_ID
+        })
diff --git a/tests/test_state.py b/tests/test_state.py
index a1ea7ef672..1a11bbcee0 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -140,13 +140,13 @@ class StateTestCase(unittest.TestCase):
                 "add_event_hashes",
             ]
         )
-        hs = Mock(spec=[
+        hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
         ])
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
-        hs.get_auth.return_value = Auth(hs)
         hs.get_clock.return_value = MockClock()
+        hs.get_auth.return_value = Auth(hs)
 
         self.state = StateHandler(hs)
         self.event_id = 0