summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-06-09 14:21:23 +0100
committerErik Johnston <erik@matrix.org>2016-06-09 14:21:23 +0100
commitba0406d10da32ebebf4185f01841f236371e0ae8 (patch)
tree70f492094b7fb962a8161bd2304c6846b3ac3f40 /synapse/rest
parentMerge pull request #801 from ruma/readme-history-storage (diff)
parentChange CHANGELOG (diff)
downloadsynapse-ba0406d10da32ebebf4185f01841f236371e0ae8.tar.xz
Merge branch 'release-v0.16.0' of github.com:matrix-org/synapse v0.16.0
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/__init__.py4
-rw-r--r--synapse/rest/client/v1/login.py67
-rw-r--r--synapse/rest/client/v1/presence.py20
-rw-r--r--synapse/rest/client/v1/push_rule.py6
-rw-r--r--synapse/rest/client/v1/pusher.py101
-rw-r--r--synapse/rest/client/v1/register.py71
-rw-r--r--synapse/rest/client/v1/room.py56
-rw-r--r--synapse/rest/client/v2_alpha/account.py5
-rw-r--r--synapse/rest/client/v2_alpha/auth.py2
-rw-r--r--synapse/rest/client/v2_alpha/openid.py96
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py36
-rw-r--r--synapse/rest/client/v2_alpha/report_event.py59
-rw-r--r--synapse/rest/client/v2_alpha/sync.py87
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py2
-rw-r--r--synapse/rest/key/v1/server_key_resource.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py4
-rw-r--r--synapse/rest/media/v1/_base.py110
-rw-r--r--synapse/rest/media/v1/base_resource.py459
-rw-r--r--synapse/rest/media/v1/download_resource.py27
-rw-r--r--synapse/rest/media/v1/media_repository.py396
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py451
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py68
-rw-r--r--synapse/rest/media/v1/upload_resource.py54
24 files changed, 1589 insertions, 595 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 6688fa8fa0..8b223e032b 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -44,6 +44,8 @@ from synapse.rest.client.v2_alpha import (
     tokenrefresh,
     tags,
     account_data,
+    report_event,
+    openid,
 )
 
 from synapse.http.server import JsonResource
@@ -86,3 +88,5 @@ class ClientRestResource(JsonResource):
         tokenrefresh.register_servlets(hs, client_resource)
         tags.register_servlets(hs, client_resource)
         account_data.register_servlets(hs, client_resource)
+        report_event.register_servlets(hs, client_resource)
+        openid.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index fe593d07ce..8df9d10efa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -43,20 +43,27 @@ class LoginRestServlet(ClientV1RestServlet):
     SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
     TOKEN_TYPE = "m.login.token"
+    JWT_TYPE = "m.login.jwt"
 
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
         self.idp_redirect_url = hs.config.saml2_idp_redirect_url
         self.password_enabled = hs.config.password_enabled
         self.saml2_enabled = hs.config.saml2_enabled
+        self.jwt_enabled = hs.config.jwt_enabled
+        self.jwt_secret = hs.config.jwt_secret
+        self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
         self.cas_server_url = hs.config.cas_server_url
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
+        self.auth_handler = self.hs.get_auth_handler()
 
     def on_GET(self, request):
         flows = []
+        if self.jwt_enabled:
+            flows.append({"type": LoginRestServlet.JWT_TYPE})
         if self.saml2_enabled:
             flows.append({"type": LoginRestServlet.SAML2_TYPE})
         if self.cas_enabled:
@@ -98,6 +105,10 @@ class LoginRestServlet(ClientV1RestServlet):
                     "uri": "%s%s" % (self.idp_redirect_url, relay_state)
                 }
                 defer.returnValue((200, result))
+            elif self.jwt_enabled and (login_submission["type"] ==
+                                       LoginRestServlet.JWT_TYPE):
+                result = yield self.do_jwt_login(login_submission)
+                defer.returnValue(result)
             # TODO Delete this after all CAS clients switch to token login instead
             elif self.cas_enabled and (login_submission["type"] ==
                                        LoginRestServlet.CAS_TYPE):
@@ -133,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
                 user_id, self.hs.hostname
             ).to_string()
 
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])
@@ -150,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_token_login(self, login_submission):
         token = login_submission['token']
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
@@ -184,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (
@@ -209,6 +220,54 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
+    @defer.inlineCallbacks
+    def do_jwt_login(self, login_submission):
+        token = login_submission.get("token", None)
+        if token is None:
+            raise LoginError(
+                401, "Token field for JWT is missing",
+                errcode=Codes.UNAUTHORIZED
+            )
+
+        import jwt
+        from jwt.exceptions import InvalidTokenError
+
+        try:
+            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+        except jwt.ExpiredSignatureError:
+            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
+        except InvalidTokenError:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user = payload.get("sub", None)
+        if user is None:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user_id = UserID.create(user, self.hs.hostname).to_string()
+        auth_handler = self.auth_handler
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if user_exists:
+            user_id, access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "home_server": self.hs.hostname,
+            }
+        else:
+            user_id, access_token = (
+                yield self.handlers.registration_handler.register(localpart=user)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+            }
+
+        defer.returnValue((200, result))
+
     # TODO Delete this after all CAS clients switch to token login instead
     def parse_cas_response(self, cas_response_body):
         root = ET.fromstring(cas_response_body)
@@ -354,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if not user_exists:
             user_id, _ = (
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 27d9ed586b..eafdce865e 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,20 +30,24 @@ logger = logging.getLogger(__name__)
 class PresenceStatusRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
 
+    def __init__(self, hs):
+        super(PresenceStatusRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_presence_handler()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         if requester.user != user:
-            allowed = yield self.handlers.presence_handler.is_visible(
+            allowed = yield self.presence_handler.is_visible(
                 observed_user=user, observer_user=requester.user,
             )
 
             if not allowed:
                 raise AuthError(403, "You are not allowed to see their presence.")
 
-        state = yield self.handlers.presence_handler.get_state(target_user=user)
+        state = yield self.presence_handler.get_state(target_user=user)
 
         defer.returnValue((200, state))
 
@@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         except:
             raise SynapseError(400, "Unable to parse state")
 
-        yield self.handlers.presence_handler.set_state(user, state)
+        yield self.presence_handler.set_state(user, state)
 
         defer.returnValue((200, {}))
 
@@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
 class PresenceListRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
 
+    def __init__(self, hs):
+        super(PresenceListRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_presence_handler()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
         if requester.user != user:
             raise SynapseError(400, "Cannot get another user's presence list")
 
-        presence = yield self.handlers.presence_handler.get_presence_list(
+        presence = yield self.presence_handler.get_presence_list(
             observer_user=user, accepted=True
         )
 
@@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
                 if len(u) == 0:
                     continue
                 invited_user = UserID.from_string(u)
-                yield self.handlers.presence_handler.send_presence_invite(
+                yield self.presence_handler.send_presence_invite(
                     observer_user=user, observed_user=invited_user
                 )
 
@@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
                 if len(u) == 0:
                     continue
                 dropped_user = UserID.from_string(u)
-                yield self.handlers.presence_handler.drop(
+                yield self.presence_handler.drop(
                     observer_user=user, observed_user=dropped_user
                 )
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 02d837ee6a..6bb4821ec6 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.store.get_push_rules_for_user(user_id)
+        rules = yield self.store.get_push_rules_for_user(user_id)
 
-        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
-
-        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
+        rules = format_push_rules_for_user(requester.user, rules)
 
         path = request.postpath[1:]
 
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9881f068c3..9a2ed6ed88 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -17,7 +17,11 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_string, RestServlet
+)
+from synapse.http.server import finish_request
+from synapse.api.errors import StoreError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -26,11 +30,48 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class PusherRestServlet(ClientV1RestServlet):
+class PushersRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/pushers$")
+
+    def __init__(self, hs):
+        super(PushersRestServlet, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+        user = requester.user
+
+        pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
+            user.to_string()
+        )
+
+        allowed_keys = [
+            "app_display_name",
+            "app_id",
+            "data",
+            "device_display_name",
+            "kind",
+            "lang",
+            "profile_tag",
+            "pushkey",
+        ]
+
+        for p in pushers:
+            for k, v in p.items():
+                if k not in allowed_keys:
+                    del p[k]
+
+        defer.returnValue((200, {"pushers": pushers}))
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
+class PushersSetRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/pushers/set$")
 
     def __init__(self, hs):
-        super(PusherRestServlet, self).__init__(hs)
+        super(PushersSetRestServlet, self).__init__(hs)
         self.notifier = hs.get_notifier()
 
     @defer.inlineCallbacks
@@ -99,5 +140,57 @@ class PusherRestServlet(ClientV1RestServlet):
         return 200, {}
 
 
+class PushersRemoveRestServlet(RestServlet):
+    """
+    To allow pusher to be delete by clicking a link (ie. GET request)
+    """
+    PATTERNS = client_path_patterns("/pushers/remove$")
+    SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
+
+    def __init__(self, hs):
+        super(RestServlet, self).__init__()
+        self.hs = hs
+        self.notifier = hs.get_notifier()
+        self.auth = hs.get_v1auth()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+        user = requester.user
+
+        app_id = parse_string(request, "app_id", required=True)
+        pushkey = parse_string(request, "pushkey", required=True)
+
+        pusher_pool = self.hs.get_pusherpool()
+
+        try:
+            yield pusher_pool.remove_pusher(
+                app_id=app_id,
+                pushkey=pushkey,
+                user_id=user.to_string(),
+            )
+        except StoreError as se:
+            if se.code != 404:
+                # This is fine: they're already unsubscribed
+                raise
+
+        self.notifier.on_new_replication_data()
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Server", self.hs.version_string)
+        request.setHeader(b"Content-Length", b"%d" % (
+            len(PushersRemoveRestServlet.SUCCESS_HTML),
+        ))
+        request.write(PushersRemoveRestServlet.SUCCESS_HTML)
+        finish_request(request)
+        defer.returnValue(None)
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
 def register_servlets(hs, http_server):
-    PusherRestServlet(hs).register(http_server)
+    PushersRestServlet(hs).register(http_server)
+    PushersSetRestServlet(hs).register(http_server)
+    PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index c6a2ef2ccc..e3f4fbb0bb 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
             )
 
 
+class CreateUserRestServlet(ClientV1RestServlet):
+    """Handles user creation via a server-to-server interface
+    """
+
+    PATTERNS = client_path_patterns("/createUser$", releases=())
+
+    def __init__(self, hs):
+        super(CreateUserRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        user_json = parse_json_object_from_request(request)
+
+        if "access_token" not in request.args:
+            raise SynapseError(400, "Expected application service token.")
+
+        app_service = yield self.store.get_app_service_by_token(
+            request.args["access_token"][0]
+        )
+        if not app_service:
+            raise SynapseError(403, "Invalid application service token.")
+
+        logger.debug("creating user: %s", user_json)
+
+        response = yield self._do_create(user_json)
+
+        defer.returnValue((200, response))
+
+    def on_OPTIONS(self, request):
+        return 403, {}
+
+    @defer.inlineCallbacks
+    def _do_create(self, user_json):
+        yield run_on_reactor()
+
+        if "localpart" not in user_json:
+            raise SynapseError(400, "Expected 'localpart' key.")
+
+        if "displayname" not in user_json:
+            raise SynapseError(400, "Expected 'displayname' key.")
+
+        if "duration_seconds" not in user_json:
+            raise SynapseError(400, "Expected 'duration_seconds' key.")
+
+        localpart = user_json["localpart"].encode("utf-8")
+        displayname = user_json["displayname"].encode("utf-8")
+        duration_seconds = 0
+        try:
+            duration_seconds = int(user_json["duration_seconds"])
+        except ValueError:
+            raise SynapseError(400, "Failed to parse 'duration_seconds'")
+        if duration_seconds > self.direct_user_creation_max_duration:
+            duration_seconds = self.direct_user_creation_max_duration
+
+        handler = self.handlers.registration_handler
+        user_id, token = yield handler.get_or_create_user(
+            localpart=localpart,
+            displayname=displayname,
+            duration_seconds=duration_seconds
+        )
+
+        defer.returnValue({
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        })
+
+
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)
+    CreateUserRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index a1fa7daf79..db52a1fc39 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -232,7 +232,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 
         if RoomID.is_valid(room_identifier):
             room_id = room_identifier
-            remote_room_hosts = None
+            try:
+                remote_room_hosts = request.args["server_name"]
+            except:
+                remote_room_hosts = None
         elif RoomAlias.is_valid(room_identifier):
             handler = self.handlers.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
@@ -276,8 +279,9 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        handler = self.handlers.room_list_handler
-        data = yield handler.get_public_room_list()
+        handler = self.hs.get_room_list_handler()
+        data = yield handler.get_aggregated_public_room_list()
+
         defer.returnValue((200, data))
 
 
@@ -405,6 +409,42 @@ class RoomEventContext(ClientV1RestServlet):
         defer.returnValue((200, results))
 
 
+class RoomForgetRestServlet(ClientV1RestServlet):
+    def register(self, http_server):
+        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
+        register_txn_path(self, PATTERNS, http_server)
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id, txn_id=None):
+        requester = yield self.auth.get_user_by_req(
+            request,
+            allow_guest=False,
+        )
+
+        yield self.handlers.room_member_handler.forget(
+            user=requester.user,
+            room_id=room_id,
+        )
+
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, room_id, txn_id):
+        try:
+            defer.returnValue(
+                self.txns.get_client_transaction(request, txn_id)
+            )
+        except KeyError:
+            pass
+
+        response = yield self.on_POST(
+            request, room_id, txn_id
+        )
+
+        self.txns.store_client_transaction(request, txn_id, response)
+        defer.returnValue(response)
+
+
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(ClientV1RestServlet):
 
@@ -534,7 +574,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(RoomTypingRestServlet, self).__init__(hs)
-        self.presence_handler = hs.get_handlers().presence_handler
+        self.presence_handler = hs.get_presence_handler()
+        self.typing_handler = hs.get_typing_handler()
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, user_id):
@@ -545,19 +586,17 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        typing_handler = self.handlers.typing_notification_handler
-
         yield self.presence_handler.bump_presence_active_time(requester.user)
 
         if content["typing"]:
-            yield typing_handler.started_typing(
+            yield self.typing_handler.started_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
                 timeout=content.get("timeout", 30000),
             )
         else:
-            yield typing_handler.stopped_typing(
+            yield self.typing_handler.stopped_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
@@ -624,6 +663,7 @@ def register_servlets(hs, http_server):
     RoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
+    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 7f8a6a4cf7..9a84873a5f 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
         super(PasswordRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -52,6 +52,7 @@ class PasswordRestServlet(RestServlet):
             defer.returnValue((401, result))
 
         user_id = None
+        requester = None
 
         if LoginType.PASSWORD in result:
             # if using password, they should also be logged in
@@ -96,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 78181b7b18..58d3cad6a1 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
         super(AuthRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
new file mode 100644
index 0000000000..aa1cae8e1e
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._base import client_v2_patterns
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import AuthError
+from synapse.util.stringutils import random_string
+
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class IdTokenServlet(RestServlet):
+    """
+    Get a bearer token that may be passed to a third party to confirm ownership
+    of a matrix user id.
+
+    The format of the response could be made compatible with the format given
+    in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
+
+    But instead of returning a signed "id_token" the response contains the
+    name of the issuing matrix homeserver. This means that for now the third
+    party will need to check the validity of the "id_token" against the
+    federation /openid/userinfo endpoint of the homeserver.
+
+    Request:
+
+    POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1
+
+    {}
+
+    Response:
+
+    HTTP/1.1 200 OK
+    {
+        "access_token": "ABDEFGH",
+        "token_type": "Bearer",
+        "matrix_server_name": "example.com",
+        "expires_in": 3600,
+    }
+    """
+    PATTERNS = client_v2_patterns(
+        "/user/(?P<user_id>[^/]*)/openid/request_token"
+    )
+
+    EXPIRES_MS = 3600 * 1000
+
+    def __init__(self, hs):
+        super(IdTokenServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.server_name = hs.config.server_name
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id):
+        requester = yield self.auth.get_user_by_req(request)
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Cannot request tokens for other users.")
+
+        # Parse the request body to make sure it's JSON, but ignore the contents
+        # for now.
+        parse_json_object_from_request(request)
+
+        token = random_string(24)
+        ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
+
+        yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
+
+        defer.returnValue((200, {
+            "access_token": token,
+            "token_type": "Bearer",
+            "matrix_server_name": self.server_name,
+            "expires_in": self.EXPIRES_MS / 1000,
+        }))
+
+
+def register_servlets(hs, http_server):
+    IdTokenServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index b831d8c95e..891cef99c6 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -37,7 +37,7 @@ class ReceiptRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_handlers().receipts_handler
-        self.presence_handler = hs.get_handlers().presence_handler
+        self.presence_handler = hs.get_presence_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, receipt_type, event_id):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d32c06c882..2088c316d1 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -48,7 +48,8 @@ class RegisterRestServlet(RestServlet):
         super(RegisterRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.store = hs.get_datastore()
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
 
@@ -100,6 +101,11 @@ class RegisterRestServlet(RestServlet):
 
         # == Application Service Registration ==
         if appservice:
+            # Set the desired user according to the AS API (which uses the
+            # 'user' key not 'username'). Since this is a new addition, we'll
+            # fallback to 'username' if they gave one.
+            if isinstance(body.get("user"), basestring):
+                desired_username = body["user"]
             result = yield self._do_appservice_registration(
                 desired_username, request.args["access_token"][0]
             )
@@ -209,6 +215,34 @@ class RegisterRestServlet(RestServlet):
                         threepid['validated_at'],
                     )
 
+                    # And we add an email pusher for them by default, but only
+                    # if email notifications are enabled (so people don't start
+                    # getting mail spam where they weren't before if email
+                    # notifs are set up on a home server)
+                    if (
+                        self.hs.config.email_enable_notifs and
+                        self.hs.config.email_notif_for_new_users
+                    ):
+                        # Pull the ID of the access token back out of the db
+                        # It would really make more sense for this to be passed
+                        # up when the access token is saved, but that's quite an
+                        # invasive change I'd rather do separately.
+                        user_tuple = yield self.store.get_user_by_access_token(
+                            token
+                        )
+
+                        yield self.hs.get_pusherpool().add_pusher(
+                            user_id=user_id,
+                            access_token=user_tuple["token_id"],
+                            kind="email",
+                            app_id="m.email",
+                            app_display_name="Email Notifications",
+                            device_display_name=threepid["address"],
+                            pushkey=threepid["address"],
+                            lang=None,  # We don't know a user's language here
+                            data={},
+                        )
+
             if 'bind_email' in params and params['bind_email']:
                 logger.info("bind_email specified: binding")
 
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
new file mode 100644
index 0000000000..8903e12405
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReportEventRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns(
+        "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(ReportEventRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id, event_id):
+        requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        body = parse_json_object_from_request(request)
+
+        yield self.store.add_event_report(
+            room_id=room_id,
+            event_id=event_id,
+            user_id=user_id,
+            reason=body.get("reason"),
+            content=body,
+            received_ts=self.clock.time_msec(),
+        )
+
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index de4a020ad4..43d8e0bf39 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -79,11 +79,10 @@ class SyncRestServlet(RestServlet):
     def __init__(self, hs):
         super(SyncRestServlet, self).__init__()
         self.auth = hs.get_auth()
-        self.event_stream_handler = hs.get_handlers().event_stream_handler
-        self.sync_handler = hs.get_handlers().sync_handler
+        self.sync_handler = hs.get_sync_handler()
         self.clock = hs.get_clock()
         self.filtering = hs.get_filtering()
-        self.presence_handler = hs.get_handlers().presence_handler
+        self.presence_handler = hs.get_presence_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -115,6 +114,8 @@ class SyncRestServlet(RestServlet):
             )
         )
 
+        request_key = (user, timeout, since, filter_id, full_state)
+
         if filter_id:
             if filter_id.startswith('{'):
                 try:
@@ -134,6 +135,7 @@ class SyncRestServlet(RestServlet):
             user=user,
             filter_collection=filter,
             is_guest=requester.is_guest,
+            request_key=request_key,
         )
 
         if since is not None:
@@ -196,15 +198,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the joined rooms in a sync result
 
-        :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
-            results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-
-        :return: the joined rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Args:
+            rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
+                results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+
+        Returns:
+            dict[str, dict[str, object]]: the joined rooms list, in our
+                response format
         """
         joined = {}
         for room in rooms:
@@ -218,15 +222,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the invited rooms in a sync result
 
-        :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
-             sync results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
+        Args:
+            rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
+                sync results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
             of transaction IDs
 
-        :return: the invited rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Returns:
+            dict[str, dict[str, object]]: the invited rooms list, in our
+                response format
         """
         invited = {}
         for room in rooms:
@@ -248,15 +254,17 @@ class SyncRestServlet(RestServlet):
         """
         Encode the archived rooms in a sync result
 
-        :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
-             sync results for rooms this user is joined to
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-
-        :return: the invited rooms list, in our response format
-        :rtype: dict[str, dict[str, object]]
+        Args:
+            rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
+                sync results for rooms this user is joined to
+            time_now(int): current time - used as a baseline for age
+                calculations
+            token_id(int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+
+        Returns:
+            dict[str, dict[str, object]]: The invited rooms list, in our
+                response format
         """
         joined = {}
         for room in rooms:
@@ -269,17 +277,18 @@ class SyncRestServlet(RestServlet):
     @staticmethod
     def encode_room(room, time_now, token_id, joined=True):
         """
-        :param JoinedSyncResult|ArchivedSyncResult room: sync result for a
-            single room
-        :param int time_now: current time - used as a baseline for age
-            calculations
-        :param int token_id: ID of the user's auth token - used for namespacing
-            of transaction IDs
-        :param joined: True if the user is joined to this room - will mean
-            we handle ephemeral events
-
-        :return: the room, encoded in our response format
-        :rtype: dict[str, object]
+        Args:
+            room (JoinedSyncResult|ArchivedSyncResult): sync result for a
+                single room
+            time_now (int): current time - used as a baseline for age
+                calculations
+            token_id (int): ID of the user's auth token - used for namespacing
+                of transaction IDs
+            joined (bool): True if the user is joined to this room - will mean
+                we handle ephemeral events
+
+        Returns:
+            dict[str, object]: the room, encoded in our response format
         """
         def serialize(event):
             # TODO(mjark): Respect formatting requirements in the filter.
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index a158c2209a..8270e8787f 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
         body = parse_json_object_from_request(request)
         try:
             old_refresh_token = body["refresh_token"]
-            auth_handler = self.hs.get_handlers().auth_handler
+            auth_handler = self.hs.get_auth_handler()
             (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
                 old_refresh_token, auth_handler.generate_refresh_token)
             new_access_token = yield auth_handler.issue_access_token(user_id)
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index 3db3838b7e..bd4fea5774 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -49,7 +49,6 @@ class LocalKey(Resource):
     """
 
     def __init__(self, hs):
-        self.hs = hs
         self.version_string = hs.version_string
         self.response_body = encode_canonical_json(
             self.response_json_object(hs.config)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9552016fec..7209d5a37d 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -97,7 +97,7 @@ class RemoteKey(Resource):
         self.async_render_GET(request)
         return NOT_DONE_YET
 
-    @request_handler
+    @request_handler()
     @defer.inlineCallbacks
     def async_render_GET(self, request):
         if len(request.postpath) == 1:
@@ -122,7 +122,7 @@ class RemoteKey(Resource):
         self.async_render_POST(request)
         return NOT_DONE_YET
 
-    @request_handler
+    @request_handler()
     @defer.inlineCallbacks
     def async_render_POST(self, request):
         content = parse_json_object_from_request(request)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
new file mode 100644
index 0000000000..b9600f2167
--- /dev/null
+++ b/synapse/rest/media/v1/_base.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import respond_with_json, finish_request
+from synapse.api.errors import (
+    cs_error, Codes, SynapseError
+)
+
+from twisted.internet import defer
+from twisted.protocols.basic import FileSender
+
+from synapse.util.stringutils import is_ascii
+
+import os
+
+import logging
+import urllib
+import urlparse
+
+logger = logging.getLogger(__name__)
+
+
+def parse_media_id(request):
+    try:
+        # This allows users to append e.g. /test.png to the URL. Useful for
+        # clients that parse the URL to see content type.
+        server_name, media_id = request.postpath[:2]
+        file_name = None
+        if len(request.postpath) > 2:
+            try:
+                file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
+            except UnicodeDecodeError:
+                pass
+        return server_name, media_id, file_name
+    except:
+        raise SynapseError(
+            404,
+            "Invalid media id token %r" % (request.postpath,),
+            Codes.UNKNOWN,
+        )
+
+
+def respond_404(request):
+    respond_with_json(
+        request, 404,
+        cs_error(
+            "Not found %r" % (request.postpath,),
+            code=Codes.NOT_FOUND,
+        ),
+        send_cors=True
+    )
+
+
+@defer.inlineCallbacks
+def respond_with_file(request, media_type, file_path,
+                      file_size=None, upload_name=None):
+    logger.debug("Responding with %r", file_path)
+
+    if os.path.isfile(file_path):
+        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+        if upload_name:
+            if is_ascii(upload_name):
+                request.setHeader(
+                    b"Content-Disposition",
+                    b"inline; filename=%s" % (
+                        urllib.quote(upload_name.encode("utf-8")),
+                    ),
+                )
+            else:
+                request.setHeader(
+                    b"Content-Disposition",
+                    b"inline; filename*=utf-8''%s" % (
+                        urllib.quote(upload_name.encode("utf-8")),
+                    ),
+                )
+
+        # cache for at least a day.
+        # XXX: we might want to turn this off for data we don't want to
+        # recommend caching as it's sensitive or private - or at least
+        # select private. don't bother setting Expires as all our
+        # clients are smart enough to be happy with Cache-Control
+        request.setHeader(
+            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+        )
+        if file_size is None:
+            stat = os.stat(file_path)
+            file_size = stat.st_size
+
+        request.setHeader(
+            b"Content-Length", b"%d" % (file_size,)
+        )
+
+        with open(file_path, "rb") as f:
+            yield FileSender().beginFileTransfer(f, request)
+
+        finish_request(request)
+    else:
+        respond_404(request)
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
deleted file mode 100644
index 58ef91c0b8..0000000000
--- a/synapse/rest/media/v1/base_resource.py
+++ /dev/null
@@ -1,459 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .thumbnailer import Thumbnailer
-
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-from synapse.http.server import respond_with_json, finish_request
-from synapse.util.stringutils import random_string
-from synapse.api.errors import (
-    cs_error, Codes, SynapseError
-)
-
-from twisted.internet import defer, threads
-from twisted.web.resource import Resource
-from twisted.protocols.basic import FileSender
-
-from synapse.util.async import ObservableDeferred
-from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
-
-import os
-
-import cgi
-import logging
-import urllib
-import urlparse
-
-logger = logging.getLogger(__name__)
-
-
-def parse_media_id(request):
-    try:
-        # This allows users to append e.g. /test.png to the URL. Useful for
-        # clients that parse the URL to see content type.
-        server_name, media_id = request.postpath[:2]
-        file_name = None
-        if len(request.postpath) > 2:
-            try:
-                file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
-            except UnicodeDecodeError:
-                pass
-        return server_name, media_id, file_name
-    except:
-        raise SynapseError(
-            404,
-            "Invalid media id token %r" % (request.postpath,),
-            Codes.UNKNOWN,
-        )
-
-
-class BaseMediaResource(Resource):
-    isLeaf = True
-
-    def __init__(self, hs, filepaths):
-        Resource.__init__(self)
-        self.auth = hs.get_auth()
-        self.client = MatrixFederationHttpClient(hs)
-        self.clock = hs.get_clock()
-        self.server_name = hs.hostname
-        self.store = hs.get_datastore()
-        self.max_upload_size = hs.config.max_upload_size
-        self.max_image_pixels = hs.config.max_image_pixels
-        self.filepaths = filepaths
-        self.version_string = hs.version_string
-        self.downloads = {}
-        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
-        self.thumbnail_requirements = hs.config.thumbnail_requirements
-
-    def _respond_404(self, request):
-        respond_with_json(
-            request, 404,
-            cs_error(
-                "Not found %r" % (request.postpath,),
-                code=Codes.NOT_FOUND,
-            ),
-            send_cors=True
-        )
-
-    @staticmethod
-    def _makedirs(filepath):
-        dirname = os.path.dirname(filepath)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-
-    def _get_remote_media(self, server_name, media_id):
-        key = (server_name, media_id)
-        download = self.downloads.get(key)
-        if download is None:
-            download = self._get_remote_media_impl(server_name, media_id)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.downloads[key] = download
-
-            @download.addBoth
-            def callback(media_info):
-                del self.downloads[key]
-                return media_info
-        return download.observe()
-
-    @defer.inlineCallbacks
-    def _get_remote_media_impl(self, server_name, media_id):
-        media_info = yield self.store.get_cached_remote_media(
-            server_name, media_id
-        )
-        if not media_info:
-            media_info = yield self._download_remote_file(
-                server_name, media_id
-            )
-        defer.returnValue(media_info)
-
-    @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id):
-        file_id = random_string(24)
-
-        fname = self.filepaths.remote_media_filepath(
-            server_name, file_id
-        )
-        self._makedirs(fname)
-
-        try:
-            with open(fname, "wb") as f:
-                request_path = "/".join((
-                    "/_matrix/media/v1/download", server_name, media_id,
-                ))
-                length, headers = yield self.client.get_file(
-                    server_name, request_path, output_stream=f,
-                    max_size=self.max_upload_size,
-                )
-            media_type = headers["Content-Type"][0]
-            time_now_ms = self.clock.time_msec()
-
-            content_disposition = headers.get("Content-Disposition", None)
-            if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
-                upload_name = None
-
-                # First check if there is a valid UTF-8 filename
-                upload_name_utf8 = params.get("filename*", None)
-                if upload_name_utf8:
-                    if upload_name_utf8.lower().startswith("utf-8''"):
-                        upload_name = upload_name_utf8[7:]
-
-                # If there isn't check for an ascii name.
-                if not upload_name:
-                    upload_name_ascii = params.get("filename", None)
-                    if upload_name_ascii and is_ascii(upload_name_ascii):
-                        upload_name = upload_name_ascii
-
-                if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
-                    try:
-                        upload_name = upload_name.decode("utf-8")
-                    except UnicodeDecodeError:
-                        upload_name = None
-            else:
-                upload_name = None
-
-            yield self.store.store_cached_remote_media(
-                origin=server_name,
-                media_id=media_id,
-                media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
-                upload_name=upload_name,
-                media_length=length,
-                filesystem_id=file_id,
-            )
-        except:
-            os.remove(fname)
-            raise
-
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        yield self._generate_remote_thumbnails(
-            server_name, media_id, media_info
-        )
-
-        defer.returnValue(media_info)
-
-    @defer.inlineCallbacks
-    def _respond_with_file(self, request, media_type, file_path,
-                           file_size=None, upload_name=None):
-        logger.debug("Responding with %r", file_path)
-
-        if os.path.isfile(file_path):
-            request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-            if upload_name:
-                if is_ascii(upload_name):
-                    request.setHeader(
-                        b"Content-Disposition",
-                        b"inline; filename=%s" % (
-                            urllib.quote(upload_name.encode("utf-8")),
-                        ),
-                    )
-                else:
-                    request.setHeader(
-                        b"Content-Disposition",
-                        b"inline; filename*=utf-8''%s" % (
-                            urllib.quote(upload_name.encode("utf-8")),
-                        ),
-                    )
-
-            # cache for at least a day.
-            # XXX: we might want to turn this off for data we don't want to
-            # recommend caching as it's sensitive or private - or at least
-            # select private. don't bother setting Expires as all our
-            # clients are smart enough to be happy with Cache-Control
-            request.setHeader(
-                b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-            )
-            if file_size is None:
-                stat = os.stat(file_path)
-                file_size = stat.st_size
-
-            request.setHeader(
-                b"Content-Length", b"%d" % (file_size,)
-            )
-
-            with open(file_path, "rb") as f:
-                yield FileSender().beginFileTransfer(f, request)
-
-            finish_request(request)
-        else:
-            self._respond_404(request)
-
-    def _get_thumbnail_requirements(self, media_type):
-        return self.thumbnail_requirements.get(media_type, ())
-
-    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
-                            t_method, t_type):
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        if m_width * m_height >= self.max_image_pixels:
-            logger.info(
-                "Image too large to thumbnail %r x %r > %r",
-                m_width, m_height, self.max_image_pixels
-            )
-            return
-
-        if t_method == "crop":
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-        elif t_method == "scale":
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-        else:
-            t_len = None
-
-        return t_len
-
-    @defer.inlineCallbacks
-    def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                        t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
-
-        t_path = self.filepaths.local_media_thumbnail(
-            media_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
-
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
-            self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
-
-        if t_len:
-            yield self.store.store_local_thumbnail(
-                media_id, t_width, t_height, t_type, t_method, t_len
-            )
-
-            defer.returnValue(t_path)
-
-    @defer.inlineCallbacks
-    def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
-                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-
-        t_path = self.filepaths.remote_media_thumbnail(
-            server_name, file_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
-
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
-            self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
-
-        if t_len:
-            yield self.store.store_remote_media_thumbnail(
-                server_name, media_id, file_id,
-                t_width, t_height, t_type, t_method, t_len
-            )
-
-            defer.returnValue(t_path)
-
-    @defer.inlineCallbacks
-    def _generate_local_thumbnails(self, media_id, media_info):
-        media_type = media_info["media_type"]
-        requirements = self._get_thumbnail_requirements(media_type)
-        if not requirements:
-            return
-
-        input_path = self.filepaths.local_media_filepath(media_id)
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        if m_width * m_height >= self.max_image_pixels:
-            logger.info(
-                "Image too large to thumbnail %r x %r > %r",
-                m_width, m_height, self.max_image_pixels
-            )
-            return
-
-        local_thumbnails = []
-
-        def generate_thumbnails():
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                ))
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                ))
-
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for l in local_thumbnails:
-            yield self.store.store_local_thumbnail(*l)
-
-        defer.returnValue({
-            "width": m_width,
-            "height": m_height,
-        })
-
-    @defer.inlineCallbacks
-    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
-        media_type = media_info["media_type"]
-        file_id = media_info["filesystem_id"]
-        requirements = self._get_thumbnail_requirements(media_type)
-        if not requirements:
-            return
-
-        remote_thumbnails = []
-
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
-
-        def generate_thumbnails():
-            if m_width * m_height >= self.max_image_pixels:
-                logger.info(
-                    "Image too large to thumbnail %r x %r > %r",
-                    m_width, m_height, self.max_image_pixels
-                )
-                return
-
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
-                    server_name, media_id, file_id,
-                    t_width, t_height, t_type, t_method, t_len
-                ])
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
-                    server_name, media_id, file_id,
-                    t_width, t_height, t_type, t_method, t_len
-                ])
-
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for r in remote_thumbnails:
-            yield self.store.store_remote_media_thumbnail(*r)
-
-        defer.returnValue({
-            "width": m_width,
-            "height": m_height,
-        })
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 1aad6b3551..9f69620772 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base_resource import BaseMediaResource, parse_media_id
+from ._base import parse_media_id, respond_with_file, respond_404
+from twisted.web.resource import Resource
 from synapse.http.server import request_handler
 
 from twisted.web.server import NOT_DONE_YET
@@ -24,12 +25,24 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DownloadResource(BaseMediaResource):
+class DownloadResource(Resource):
+    isLeaf = True
+
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.filepaths = media_repo.filepaths
+        self.media_repo = media_repo
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET
 
-    @request_handler
+    @request_handler()
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
         server_name, media_id, name = parse_media_id(request)
@@ -44,7 +57,7 @@ class DownloadResource(BaseMediaResource):
     def _respond_local_file(self, request, media_id, name):
         media_info = yield self.store.get_local_media(media_id)
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         media_type = media_info["media_type"]
@@ -52,14 +65,14 @@ class DownloadResource(BaseMediaResource):
         upload_name = name if name else media_info["upload_name"]
         file_path = self.filepaths.local_media_filepath(media_id)
 
-        yield self._respond_with_file(
+        yield respond_with_file(
             request, media_type, file_path, media_length,
             upload_name=upload_name,
         )
 
     @defer.inlineCallbacks
     def _respond_remote_file(self, request, server_name, media_id, name):
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
 
         media_type = media_info["media_type"]
         media_length = media_info["media_length"]
@@ -70,7 +83,7 @@ class DownloadResource(BaseMediaResource):
             server_name, filesystem_id
         )
 
-        yield self._respond_with_file(
+        yield respond_with_file(
             request, media_type, file_path, media_length,
             upload_name=upload_name,
         )
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 7dfb027dd1..d96bf9afe2 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -17,15 +17,400 @@ from .upload_resource import UploadResource
 from .download_resource import DownloadResource
 from .thumbnail_resource import ThumbnailResource
 from .identicon_resource import IdenticonResource
+from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths
 
 from twisted.web.resource import Resource
 
+from .thumbnailer import Thumbnailer
+
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.util.stringutils import random_string
+
+from twisted.internet import defer, threads
+
+from synapse.util.async import ObservableDeferred
+from synapse.util.stringutils import is_ascii
+from synapse.util.logcontext import preserve_context_over_fn
+
+import os
+
+import cgi
 import logging
+import urlparse
 
 logger = logging.getLogger(__name__)
 
 
+class MediaRepository(object):
+    def __init__(self, hs, filepaths):
+        self.auth = hs.get_auth()
+        self.client = MatrixFederationHttpClient(hs)
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.max_upload_size = hs.config.max_upload_size
+        self.max_image_pixels = hs.config.max_image_pixels
+        self.filepaths = filepaths
+        self.downloads = {}
+        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+        self.thumbnail_requirements = hs.config.thumbnail_requirements
+
+    @staticmethod
+    def _makedirs(filepath):
+        dirname = os.path.dirname(filepath)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+    @defer.inlineCallbacks
+    def create_content(self, media_type, upload_name, content, content_length,
+                       auth_user):
+        media_id = random_string(24)
+
+        fname = self.filepaths.local_media_filepath(media_id)
+        self._makedirs(fname)
+
+        # This shouldn't block for very long because the content will have
+        # already been uploaded at this point.
+        with open(fname, "wb") as f:
+            f.write(content)
+
+        yield self.store.store_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+        media_info = {
+            "media_type": media_type,
+            "media_length": content_length,
+        }
+
+        yield self._generate_local_thumbnails(media_id, media_info)
+
+        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+
+    def get_remote_media(self, server_name, media_id):
+        key = (server_name, media_id)
+        download = self.downloads.get(key)
+        if download is None:
+            download = self._get_remote_media_impl(server_name, media_id)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self.downloads[key] = download
+
+            @download.addBoth
+            def callback(media_info):
+                del self.downloads[key]
+                return media_info
+        return download.observe()
+
+    @defer.inlineCallbacks
+    def _get_remote_media_impl(self, server_name, media_id):
+        media_info = yield self.store.get_cached_remote_media(
+            server_name, media_id
+        )
+        if not media_info:
+            media_info = yield self._download_remote_file(
+                server_name, media_id
+            )
+        defer.returnValue(media_info)
+
+    @defer.inlineCallbacks
+    def _download_remote_file(self, server_name, media_id):
+        file_id = random_string(24)
+
+        fname = self.filepaths.remote_media_filepath(
+            server_name, file_id
+        )
+        self._makedirs(fname)
+
+        try:
+            with open(fname, "wb") as f:
+                request_path = "/".join((
+                    "/_matrix/media/v1/download", server_name, media_id,
+                ))
+                length, headers = yield self.client.get_file(
+                    server_name, request_path, output_stream=f,
+                    max_size=self.max_upload_size,
+                )
+            media_type = headers["Content-Type"][0]
+            time_now_ms = self.clock.time_msec()
+
+            content_disposition = headers.get("Content-Disposition", None)
+            if content_disposition:
+                _, params = cgi.parse_header(content_disposition[0],)
+                upload_name = None
+
+                # First check if there is a valid UTF-8 filename
+                upload_name_utf8 = params.get("filename*", None)
+                if upload_name_utf8:
+                    if upload_name_utf8.lower().startswith("utf-8''"):
+                        upload_name = upload_name_utf8[7:]
+
+                # If there isn't check for an ascii name.
+                if not upload_name:
+                    upload_name_ascii = params.get("filename", None)
+                    if upload_name_ascii and is_ascii(upload_name_ascii):
+                        upload_name = upload_name_ascii
+
+                if upload_name:
+                    upload_name = urlparse.unquote(upload_name)
+                    try:
+                        upload_name = upload_name.decode("utf-8")
+                    except UnicodeDecodeError:
+                        upload_name = None
+            else:
+                upload_name = None
+
+            yield self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
+        except:
+            os.remove(fname)
+            raise
+
+        media_info = {
+            "media_type": media_type,
+            "media_length": length,
+            "upload_name": upload_name,
+            "created_ts": time_now_ms,
+            "filesystem_id": file_id,
+        }
+
+        yield self._generate_remote_thumbnails(
+            server_name, media_id, media_info
+        )
+
+        defer.returnValue(media_info)
+
+    def _get_thumbnail_requirements(self, media_type):
+        return self.thumbnail_requirements.get(media_type, ())
+
+    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+                            t_method, t_type):
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        if m_width * m_height >= self.max_image_pixels:
+            logger.info(
+                "Image too large to thumbnail %r x %r > %r",
+                m_width, m_height, self.max_image_pixels
+            )
+            return
+
+        if t_method == "crop":
+            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+        elif t_method == "scale":
+            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+        else:
+            t_len = None
+
+        return t_len
+
+    @defer.inlineCallbacks
+    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
+                                       t_method, t_type):
+        input_path = self.filepaths.local_media_filepath(media_id)
+
+        t_path = self.filepaths.local_media_thumbnail(
+            media_id, t_width, t_height, t_type, t_method
+        )
+        self._makedirs(t_path)
+
+        t_len = yield preserve_context_over_fn(
+            threads.deferToThread,
+            self._generate_thumbnail,
+            input_path, t_path, t_width, t_height, t_method, t_type
+        )
+
+        if t_len:
+            yield self.store.store_local_thumbnail(
+                media_id, t_width, t_height, t_type, t_method, t_len
+            )
+
+            defer.returnValue(t_path)
+
+    @defer.inlineCallbacks
+    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
+                                        t_width, t_height, t_method, t_type):
+        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+
+        t_path = self.filepaths.remote_media_thumbnail(
+            server_name, file_id, t_width, t_height, t_type, t_method
+        )
+        self._makedirs(t_path)
+
+        t_len = yield preserve_context_over_fn(
+            threads.deferToThread,
+            self._generate_thumbnail,
+            input_path, t_path, t_width, t_height, t_method, t_type
+        )
+
+        if t_len:
+            yield self.store.store_remote_media_thumbnail(
+                server_name, media_id, file_id,
+                t_width, t_height, t_type, t_method, t_len
+            )
+
+            defer.returnValue(t_path)
+
+    @defer.inlineCallbacks
+    def _generate_local_thumbnails(self, media_id, media_info):
+        media_type = media_info["media_type"]
+        requirements = self._get_thumbnail_requirements(media_type)
+        if not requirements:
+            return
+
+        input_path = self.filepaths.local_media_filepath(media_id)
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        if m_width * m_height >= self.max_image_pixels:
+            logger.info(
+                "Image too large to thumbnail %r x %r > %r",
+                m_width, m_height, self.max_image_pixels
+            )
+            return
+
+        local_thumbnails = []
+
+        def generate_thumbnails():
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.local_media_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                local_thumbnails.append((
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                ))
+
+        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
+
+        for l in local_thumbnails:
+            yield self.store.store_local_thumbnail(*l)
+
+        defer.returnValue({
+            "width": m_width,
+            "height": m_height,
+        })
+
+    @defer.inlineCallbacks
+    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
+        media_type = media_info["media_type"]
+        file_id = media_info["filesystem_id"]
+        requirements = self._get_thumbnail_requirements(media_type)
+        if not requirements:
+            return
+
+        remote_thumbnails = []
+
+        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+        thumbnailer = Thumbnailer(input_path)
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        def generate_thumbnails():
+            if m_width * m_height >= self.max_image_pixels:
+                logger.info(
+                    "Image too large to thumbnail %r x %r > %r",
+                    m_width, m_height, self.max_image_pixels
+                )
+                return
+
+            scales = set()
+            crops = set()
+            for r_width, r_height, r_method, r_type in requirements:
+                if r_method == "scale":
+                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                    scales.add((
+                        min(m_width, t_width), min(m_height, t_height), r_type,
+                    ))
+                elif r_method == "crop":
+                    crops.add((r_width, r_height, r_type))
+
+            for t_width, t_height, t_type in scales:
+                t_method = "scale"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+            for t_width, t_height, t_type in crops:
+                if (t_width, t_height, t_type) in scales:
+                    # If the aspect ratio of the cropped thumbnail matches a purely
+                    # scaled one then there is no point in calculating a separate
+                    # thumbnail.
+                    continue
+                t_method = "crop"
+                t_path = self.filepaths.remote_media_thumbnail(
+                    server_name, file_id, t_width, t_height, t_type, t_method
+                )
+                self._makedirs(t_path)
+                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+                remote_thumbnails.append([
+                    server_name, media_id, file_id,
+                    t_width, t_height, t_type, t_method, t_len
+                ])
+
+        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
+
+        for r in remote_thumbnails:
+            yield self.store.store_remote_media_thumbnail(*r)
+
+        defer.returnValue({
+            "width": m_width,
+            "height": m_height,
+        })
+
+
 class MediaRepositoryResource(Resource):
     """File uploading and downloading.
 
@@ -74,7 +459,12 @@ class MediaRepositoryResource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
         filepaths = MediaFilePaths(hs.config.media_store_path)
-        self.putChild("upload", UploadResource(hs, filepaths))
-        self.putChild("download", DownloadResource(hs, filepaths))
-        self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
+
+        media_repo = MediaRepository(hs, filepaths)
+
+        self.putChild("upload", UploadResource(hs, media_repo))
+        self.putChild("download", DownloadResource(hs, media_repo))
+        self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
         self.putChild("identicon", IdenticonResource())
+        if hs.config.url_preview_enabled:
+            self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
new file mode 100644
index 0000000000..37dd1de899
--- /dev/null
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -0,0 +1,451 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+from twisted.web.resource import Resource
+
+from synapse.api.errors import (
+    SynapseError, Codes,
+)
+from synapse.util.stringutils import random_string
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.http.client import SpiderHttpClient
+from synapse.http.server import (
+    request_handler, respond_with_json_bytes
+)
+from synapse.util.async import ObservableDeferred
+from synapse.util.stringutils import is_ascii
+
+import os
+import re
+import fnmatch
+import cgi
+import ujson as json
+import urlparse
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class PreviewUrlResource(Resource):
+    isLeaf = True
+
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.version_string = hs.version_string
+        self.filepaths = media_repo.filepaths
+        self.max_spider_size = hs.config.max_spider_size
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.client = SpiderHttpClient(hs)
+        self.media_repo = media_repo
+
+        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
+
+        # simple memory cache mapping urls to OG metadata
+        self.cache = ExpiringCache(
+            cache_name="url_previews",
+            clock=self.clock,
+            # don't spider URLs more often than once an hour
+            expiry_ms=60 * 60 * 1000,
+        )
+        self.cache.start()
+
+        self.downloads = {}
+
+    def render_GET(self, request):
+        self._async_render_GET(request)
+        return NOT_DONE_YET
+
+    @request_handler()
+    @defer.inlineCallbacks
+    def _async_render_GET(self, request):
+
+        # XXX: if get_user_by_req fails, what should we do in an async render?
+        requester = yield self.auth.get_user_by_req(request)
+        url = request.args.get("url")[0]
+        if "ts" in request.args:
+            ts = int(request.args.get("ts")[0])
+        else:
+            ts = self.clock.time_msec()
+
+        url_tuple = urlparse.urlsplit(url)
+        for entry in self.url_preview_url_blacklist:
+            match = True
+            for attrib in entry:
+                pattern = entry[attrib]
+                value = getattr(url_tuple, attrib)
+                logger.debug((
+                    "Matching attrib '%s' with value '%s' against"
+                    " pattern '%s'"
+                ) % (attrib, value, pattern))
+
+                if value is None:
+                    match = False
+                    continue
+
+                if pattern.startswith('^'):
+                    if not re.match(pattern, getattr(url_tuple, attrib)):
+                        match = False
+                        continue
+                else:
+                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
+                        match = False
+                        continue
+            if match:
+                logger.warn(
+                    "URL %s blocked by url_blacklist entry %s", url, entry
+                )
+                raise SynapseError(
+                    403, "URL blocked by url pattern blacklist entry",
+                    Codes.UNKNOWN
+                )
+
+        # first check the memory cache - good to handle all the clients on this
+        # HS thundering away to preview the same URL at the same time.
+        og = self.cache.get(url)
+        if og:
+            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+            return
+
+        # then check the URL cache in the DB (which will also provide us with
+        # historical previews, if we have any)
+        cache_result = yield self.store.get_url_cache(url, ts)
+        if (
+            cache_result and
+            cache_result["download_ts"] + cache_result["expires"] > ts and
+            cache_result["response_code"] / 100 == 2
+        ):
+            respond_with_json_bytes(
+                request, 200, cache_result["og"].encode('utf-8'),
+                send_cors=True
+            )
+            return
+
+        # Ensure only one download for a given URL is active at a time
+        download = self.downloads.get(url)
+        if download is None:
+            download = self._download_url(url, requester.user)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self.downloads[url] = download
+
+            @download.addBoth
+            def callback(media_info):
+                del self.downloads[url]
+                return media_info
+        media_info = yield download.observe()
+
+        # FIXME: we should probably update our cache now anyway, so that
+        # even if the OG calculation raises, we don't keep hammering on the
+        # remote server.  For now, leave it uncached to aid debugging OG
+        # calculation problems
+
+        logger.debug("got media_info of '%s'" % media_info)
+
+        if self._is_media(media_info['media_type']):
+            dims = yield self.media_repo._generate_local_thumbnails(
+                media_info['filesystem_id'], media_info
+            )
+
+            og = {
+                "og:description": media_info['download_name'],
+                "og:image": "mxc://%s/%s" % (
+                    self.server_name, media_info['filesystem_id']
+                ),
+                "og:image:type": media_info['media_type'],
+                "matrix:image:size": media_info['media_length'],
+            }
+
+            if dims:
+                og["og:image:width"] = dims['width']
+                og["og:image:height"] = dims['height']
+            else:
+                logger.warn("Couldn't get dims for %s" % url)
+
+            # define our OG response for this media
+        elif self._is_html(media_info['media_type']):
+            # TODO: somehow stop a big HTML tree from exploding synapse's RAM
+
+            from lxml import etree
+
+            file = open(media_info['filename'])
+            body = file.read()
+            file.close()
+
+            # clobber the encoding from the content-type, or default to utf-8
+            # XXX: this overrides any <meta/> or XML charset headers in the body
+            # which may pose problems, but so far seems to work okay.
+            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
+            encoding = match.group(1) if match else "utf-8"
+
+            try:
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body, parser)
+                og = yield self._calc_og(tree, media_info, requester)
+            except UnicodeDecodeError:
+                # blindly try decoding the body as utf-8, which seems to fix
+                # the charset mismatches on https://google.com
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
+                og = yield self._calc_og(tree, media_info, requester)
+
+        else:
+            logger.warn("Failed to find any OG data in %s", url)
+            og = {}
+
+        logger.debug("Calculated OG for %s as %s" % (url, og))
+
+        # store OG in ephemeral in-memory cache
+        self.cache[url] = og
+
+        # store OG in history-aware DB cache
+        yield self.store.store_url_cache(
+            url,
+            media_info["response_code"],
+            media_info["etag"],
+            media_info["expires"],
+            json.dumps(og),
+            media_info["filesystem_id"],
+            media_info["created_ts"],
+        )
+
+        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+
+    @defer.inlineCallbacks
+    def _calc_og(self, tree, media_info, requester):
+        # suck our tree into lxml and define our OG response.
+
+        # if we see any image URLs in the OG response, then spider them
+        # (although the client could choose to do this by asking for previews of those
+        # URLs to avoid DoSing the server)
+
+        # "og:type"         : "video",
+        # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+        # "og:site_name"    : "YouTube",
+        # "og:video:type"   : "application/x-shockwave-flash",
+        # "og:description"  : "Fun stuff happening here",
+        # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+        # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+        # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+        # "og:video:width"  : "1280"
+        # "og:video:height" : "720",
+        # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+        og = {}
+        for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+            og[tag.attrib['property']] = tag.attrib['content']
+
+        # TODO: grab article: meta tags too, e.g.:
+
+        # "article:publisher" : "https://www.facebook.com/thethudonline" />
+        # "article:author" content="https://www.facebook.com/thethudonline" />
+        # "article:tag" content="baby" />
+        # "article:section" content="Breaking News" />
+        # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+        # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+        if 'og:title' not in og:
+            # do some basic spidering of the HTML
+            title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+            og['og:title'] = title[0].text.strip() if title else None
+
+        if 'og:image' not in og:
+            # TODO: extract a favicon failing all else
+            meta_image = tree.xpath(
+                "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+            )
+            if meta_image:
+                og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
+            else:
+                # TODO: consider inlined CSS styles as well as width & height attribs
+                images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+                images = sorted(images, key=lambda i: (
+                    -1 * int(i.attrib['width']) * int(i.attrib['height'])
+                ))
+                if not images:
+                    images = tree.xpath("//img[@src]")
+                if images:
+                    og['og:image'] = images[0].attrib['src']
+
+        # pre-cache the image for posterity
+        # FIXME: it might be cleaner to use the same flow as the main /preview_url request
+        # itself and benefit from the same caching etc.  But for now we just rely on the
+        # caching on the master request to speed things up.
+        if 'og:image' in og and og['og:image']:
+            image_info = yield self._download_url(
+                self._rebase_url(og['og:image'], media_info['uri']), requester.user
+            )
+
+            if self._is_media(image_info['media_type']):
+                # TODO: make sure we don't choke on white-on-transparent images
+                dims = yield self.media_repo._generate_local_thumbnails(
+                    image_info['filesystem_id'], image_info
+                )
+                if dims:
+                    og["og:image:width"] = dims['width']
+                    og["og:image:height"] = dims['height']
+                else:
+                    logger.warn("Couldn't get dims for %s" % og["og:image"])
+
+                og["og:image"] = "mxc://%s/%s" % (
+                    self.server_name, image_info['filesystem_id']
+                )
+                og["og:image:type"] = image_info['media_type']
+                og["matrix:image:size"] = image_info['media_length']
+            else:
+                del og["og:image"]
+
+        if 'og:description' not in og:
+            meta_description = tree.xpath(
+                "//*/meta"
+                "[translate(@name, 'DESCRIPTION', 'description')='description']"
+                "/@content")
+            if meta_description:
+                og['og:description'] = meta_description[0]
+            else:
+                # grab any text nodes which are inside the <body/> tag...
+                # unless they are within an HTML5 semantic markup tag...
+                # <header/>, <nav/>, <aside/>, <footer/>
+                # ...or if they are within a <script/> or <style/> tag.
+                # This is a very very very coarse approximation to a plain text
+                # render of the page.
+                text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
+                                        "ancestor::aside | ancestor::footer | "
+                                        "ancestor::script | ancestor::style)]" +
+                                        "[ancestor::body]")
+                text = ''
+                for text_node in text_nodes:
+                    if len(text) < 500:
+                        text += text_node + ' '
+                    else:
+                        break
+                text = re.sub(r'[\t ]+', ' ', text)
+                text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
+                text = text.strip()[:500]
+                og['og:description'] = text if text else None
+
+        # TODO: delete the url downloads to stop diskfilling,
+        # as we only ever cared about its OG
+        defer.returnValue(og)
+
+    def _rebase_url(self, url, base):
+        base = list(urlparse.urlparse(base))
+        url = list(urlparse.urlparse(url))
+        if not url[0]:  # fix up schema
+            url[0] = base[0] or "http"
+        if not url[1]:  # fix up hostname
+            url[1] = base[1]
+            if not url[2].startswith('/'):
+                url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
+        return urlparse.urlunparse(url)
+
+    @defer.inlineCallbacks
+    def _download_url(self, url, user):
+        # TODO: we should probably honour robots.txt... except in practice
+        # we're most likely being explicitly triggered by a human rather than a
+        # bot, so are we really a robot?
+
+        # XXX: horrible duplication with base_resource's _download_remote_file()
+        file_id = random_string(24)
+
+        fname = self.filepaths.local_media_filepath(file_id)
+        self.media_repo._makedirs(fname)
+
+        try:
+            with open(fname, "wb") as f:
+                logger.debug("Trying to get url '%s'" % url)
+                length, headers, uri, code = yield self.client.get_file(
+                    url, output_stream=f, max_size=self.max_spider_size,
+                )
+                # FIXME: pass through 404s and other error messages nicely
+
+            media_type = headers["Content-Type"][0]
+            time_now_ms = self.clock.time_msec()
+
+            content_disposition = headers.get("Content-Disposition", None)
+            if content_disposition:
+                _, params = cgi.parse_header(content_disposition[0],)
+                download_name = None
+
+                # First check if there is a valid UTF-8 filename
+                download_name_utf8 = params.get("filename*", None)
+                if download_name_utf8:
+                    if download_name_utf8.lower().startswith("utf-8''"):
+                        download_name = download_name_utf8[7:]
+
+                # If there isn't check for an ascii name.
+                if not download_name:
+                    download_name_ascii = params.get("filename", None)
+                    if download_name_ascii and is_ascii(download_name_ascii):
+                        download_name = download_name_ascii
+
+                if download_name:
+                    download_name = urlparse.unquote(download_name)
+                    try:
+                        download_name = download_name.decode("utf-8")
+                    except UnicodeDecodeError:
+                        download_name = None
+            else:
+                download_name = None
+
+            yield self.store.store_local_media(
+                media_id=file_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=download_name,
+                media_length=length,
+                user_id=user,
+            )
+
+        except Exception as e:
+            os.remove(fname)
+            raise SynapseError(
+                500, ("Failed to download content: %s" % e),
+                Codes.UNKNOWN
+            )
+
+        defer.returnValue({
+            "media_type": media_type,
+            "media_length": length,
+            "download_name": download_name,
+            "created_ts": time_now_ms,
+            "filesystem_id": file_id,
+            "filename": fname,
+            "uri": uri,
+            "response_code": code,
+            # FIXME: we should calculate a proper expiration based on the
+            # Cache-Control and Expire headers.  But for now, assume 1 hour.
+            "expires": 60 * 60 * 1000,
+            "etag": headers["ETag"][0] if "ETag" in headers else None,
+        })
+
+    def _is_media(self, content_type):
+        if content_type.lower().startswith("image/"):
+            return True
+
+    def _is_html(self, content_type):
+        content_type = content_type.lower()
+        if (
+            content_type.startswith("text/html") or
+            content_type.startswith("application/xhtml")
+        ):
+            return True
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index ab52499785..0b9e1de1a7 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 
-from .base_resource import BaseMediaResource, parse_media_id
+from ._base import parse_media_id, respond_404, respond_with_file
+from twisted.web.resource import Resource
 from synapse.http.servlet import parse_string, parse_integer
 from synapse.http.server import request_handler
 
@@ -26,14 +27,25 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class ThumbnailResource(BaseMediaResource):
+class ThumbnailResource(Resource):
     isLeaf = True
 
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.store = hs.get_datastore()
+        self.filepaths = media_repo.filepaths
+        self.media_repo = media_repo
+        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+        self.server_name = hs.hostname
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET
 
-    @request_handler
+    @request_handler()
     @defer.inlineCallbacks
     def _async_render_GET(self, request):
         server_name, media_id, _ = parse_media_id(request)
@@ -69,9 +81,14 @@ class ThumbnailResource(BaseMediaResource):
         media_info = yield self.store.get_local_media(media_id)
 
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.local_media_filepath(media_id)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
+        #     return
+
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
 
         if thumbnail_infos:
@@ -86,7 +103,7 @@ class ThumbnailResource(BaseMediaResource):
             file_path = self.filepaths.local_media_thumbnail(
                 media_id, t_width, t_height, t_type, t_method,
             )
-            yield self._respond_with_file(request, t_type, file_path)
+            yield respond_with_file(request, t_type, file_path)
 
         else:
             yield self._respond_default_thumbnail(
@@ -100,9 +117,14 @@ class ThumbnailResource(BaseMediaResource):
         media_info = yield self.store.get_local_media(media_id)
 
         if not media_info:
-            self._respond_404(request)
+            respond_404(request)
             return
 
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.local_media_filepath(media_id)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
+        #     return
+
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
         for info in thumbnail_infos:
             t_w = info["thumbnail_width"] == desired_width
@@ -114,18 +136,18 @@ class ThumbnailResource(BaseMediaResource):
                 file_path = self.filepaths.local_media_thumbnail(
                     media_id, desired_width, desired_height, desired_type, desired_method,
                 )
-                yield self._respond_with_file(request, desired_type, file_path)
+                yield respond_with_file(request, desired_type, file_path)
                 return
 
         logger.debug("We don't have a local thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self._generate_local_exact_thumbnail(
+        file_path = yield self.media_repo.generate_local_exact_thumbnail(
             media_id, desired_width, desired_height, desired_method, desired_type
         )
 
         if file_path:
-            yield self._respond_with_file(request, desired_type, file_path)
+            yield respond_with_file(request, desired_type, file_path)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, desired_width, desired_height,
@@ -136,7 +158,12 @@ class ThumbnailResource(BaseMediaResource):
     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
                                              desired_width, desired_height,
                                              desired_method, desired_type):
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
+
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
+        #     return
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -155,19 +182,19 @@ class ThumbnailResource(BaseMediaResource):
                     server_name, file_id, desired_width, desired_height,
                     desired_type, desired_method,
                 )
-                yield self._respond_with_file(request, desired_type, file_path)
+                yield respond_with_file(request, desired_type, file_path)
                 return
 
         logger.debug("We don't have a local thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self._generate_remote_exact_thumbnail(
+        file_path = yield self.media_repo.generate_remote_exact_thumbnail(
             server_name, file_id, media_id, desired_width,
             desired_height, desired_method, desired_type
         )
 
         if file_path:
-            yield self._respond_with_file(request, desired_type, file_path)
+            yield respond_with_file(request, desired_type, file_path)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, desired_width, desired_height,
@@ -179,7 +206,12 @@ class ThumbnailResource(BaseMediaResource):
                                   height, method, m_type):
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead.
-        media_info = yield self._get_remote_media(server_name, media_id)
+        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
+
+        # if media_info["media_type"] == "image/svg+xml":
+        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
+        #     yield respond_with_file(request, media_info["media_type"], file_path)
+        #     return
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -199,7 +231,7 @@ class ThumbnailResource(BaseMediaResource):
             file_path = self.filepaths.remote_media_thumbnail(
                 server_name, file_id, t_width, t_height, t_type, t_method,
             )
-            yield self._respond_with_file(request, t_type, file_path, t_length)
+            yield respond_with_file(request, t_type, file_path, t_length)
         else:
             yield self._respond_default_thumbnail(
                 request, media_info, width, height, method, m_type,
@@ -208,6 +240,8 @@ class ThumbnailResource(BaseMediaResource):
     @defer.inlineCallbacks
     def _respond_default_thumbnail(self, request, media_info, width, height,
                                    method, m_type):
+        # XXX: how is this meant to work? store.get_default_thumbnails
+        # appears to always return [] so won't this always 404?
         media_type = media_info["media_type"]
         top_level_type = media_type.split("/")[0]
         sub_type = media_type.split("/")[-1].split(";")[0]
@@ -223,7 +257,7 @@ class ThumbnailResource(BaseMediaResource):
                 "_default", "_default",
             )
         if not thumbnail_infos:
-            self._respond_404(request)
+            respond_404(request)
             return
 
         thumbnail_info = self._select_thumbnail(
@@ -239,7 +273,7 @@ class ThumbnailResource(BaseMediaResource):
         file_path = self.filepaths.default_thumbnail(
             top_level_type, sub_type, t_width, t_height, t_type, t_method,
         )
-        yield self.respond_with_file(request, t_type, file_path, t_length)
+        yield respond_with_file(request, t_type, file_path, t_length)
 
     def _select_thumbnail(self, desired_width, desired_height, desired_method,
                           desired_type, thumbnail_infos):
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 9c7ad4ae85..b716d1d892 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,34 @@
 
 from synapse.http.server import respond_with_json, request_handler
 
-from synapse.util.stringutils import random_string
 from synapse.api.errors import SynapseError
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
 
-from .base_resource import BaseMediaResource
+from twisted.web.resource import Resource
 
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class UploadResource(BaseMediaResource):
+class UploadResource(Resource):
+    isLeaf = True
+
+    def __init__(self, hs, media_repo):
+        Resource.__init__(self)
+
+        self.media_repo = media_repo
+        self.filepaths = media_repo.filepaths
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.auth = hs.get_auth()
+        self.max_upload_size = hs.config.max_upload_size
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
     def render_POST(self, request):
         self._async_render_POST(request)
         return NOT_DONE_YET
@@ -37,37 +51,7 @@ class UploadResource(BaseMediaResource):
         respond_with_json(request, 200, {}, send_cors=True)
         return NOT_DONE_YET
 
-    @defer.inlineCallbacks
-    def create_content(self, media_type, upload_name, content, content_length,
-                       auth_user):
-        media_id = random_string(24)
-
-        fname = self.filepaths.local_media_filepath(media_id)
-        self._makedirs(fname)
-
-        # This shouldn't block for very long because the content will have
-        # already been uploaded at this point.
-        with open(fname, "wb") as f:
-            f.write(content)
-
-        yield self.store.store_local_media(
-            media_id=media_id,
-            media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
-            upload_name=upload_name,
-            media_length=content_length,
-            user_id=auth_user,
-        )
-        media_info = {
-            "media_type": media_type,
-            "media_length": content_length,
-        }
-
-        yield self._generate_local_thumbnails(media_id, media_info)
-
-        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
-
-    @request_handler
+    @request_handler()
     @defer.inlineCallbacks
     def _async_render_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
@@ -108,7 +92,7 @@ class UploadResource(BaseMediaResource):
         #     disposition = headers.getRawHeaders("Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
-        content_uri = yield self.create_content(
+        content_uri = yield self.media_repo.create_content(
             media_type, upload_name, request.content.read(),
             content_length, requester.user
         )