summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_consent.py118
-rw-r--r--tests/rest/client/v1/test_admin.py149
-rw-r--r--tests/rest/client/v1/test_events.py100
-rw-r--r--tests/rest/client/v1/test_register.py89
-rw-r--r--tests/rest/client/v1/utils.py22
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py104
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py78
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py95
-rw-r--r--tests/rest/client/v2_alpha/test_register.py138
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py123
-rw-r--r--tests/rest/media/v1/test_media_storage.py146
-rw-r--r--tests/rest/media/v1/test_url_preview.py470
-rw-r--r--tests/rest/test_well_known.py58
13 files changed, 1305 insertions, 385 deletions
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
new file mode 100644
index 0000000000..4294bbec2a
--- /dev/null
+++ b/tests/rest/client/test_consent.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from synapse.api.urls import ConsentURIBuilder
+from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.consent import consent_resource
+
+from tests import unittest
+from tests.server import render
+
+try:
+    from synapse.push.mailer import load_jinja2_templates
+except Exception:
+    load_jinja2_templates = None
+
+
+class ConsentResourceTestCase(unittest.HomeserverTestCase):
+    skip = "No Jinja installed" if not load_jinja2_templates else None
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+    user_id = True
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+
+        config = self.default_config()
+        config.user_consent_version = "1"
+        config.public_baseurl = ""
+        config.form_secret = "123abc"
+
+        # Make some temporary templates...
+        temp_consent_path = self.mktemp()
+        os.mkdir(temp_consent_path)
+        os.mkdir(os.path.join(temp_consent_path, 'en'))
+        config.user_consent_template_dir = os.path.abspath(temp_consent_path)
+
+        with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+            f.write("{{version}},{{has_consented}}")
+
+        with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+            f.write("yay!")
+
+        hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def test_render_public_consent(self):
+        """You can observe the terms form without specifying a user"""
+        resource = consent_resource.ConsentResource(self.hs)
+        request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
+        render(request, resource, self.reactor)
+        self.assertEqual(channel.code, 200)
+
+    def test_accept_consent(self):
+        """
+        A user can use the consent form to accept the terms.
+        """
+        uri_builder = ConsentURIBuilder(self.hs.config)
+        resource = consent_resource.ConsentResource(self.hs)
+
+        # Register a user
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Fetch the consent page, to get the consent version
+        consent_uri = (
+            uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+            + "&u=user"
+        )
+        request, channel = self.make_request(
+            "GET", consent_uri, access_token=access_token, shorthand=False
+        )
+        render(request, resource, self.reactor)
+        self.assertEqual(channel.code, 200)
+
+        # Get the version from the body, and whether we've consented
+        version, consented = channel.result["body"].decode('ascii').split(",")
+        self.assertEqual(consented, "False")
+
+        # POST to the consent page, saying we've agreed
+        request, channel = self.make_request(
+            "POST",
+            consent_uri + "&v=" + version,
+            access_token=access_token,
+            shorthand=False,
+        )
+        render(request, resource, self.reactor)
+        self.assertEqual(channel.code, 200)
+
+        # Fetch the consent page, to get the consent version -- it should have
+        # changed
+        request, channel = self.make_request(
+            "GET", consent_uri, access_token=access_token, shorthand=False
+        )
+        render(request, resource, self.reactor)
+        self.assertEqual(channel.code, 200)
+
+        # Get the version from the body, and check that it's the version we
+        # agreed to, and that we've consented to it.
+        version, consented = channel.result["body"].decode('ascii').split(",")
+        self.assertEqual(consented, "True")
+        self.assertEqual(version, "1")
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 1a553fa3f9..407bf0ac4c 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -19,24 +19,18 @@ import json
 
 from mock import Mock
 
-from synapse.http.server import JsonResource
+from synapse.api.constants import UserTypes
 from synapse.rest.client.v1.admin import register_servlets
-from synapse.util import Clock
 
 from tests import unittest
-from tests.server import (
-    ThreadedMemoryReactorClock,
-    make_request,
-    render,
-    setup_test_homeserver,
-)
 
 
-class UserRegisterTestCase(unittest.TestCase):
-    def setUp(self):
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+    servlets = [register_servlets]
+
+    def make_homeserver(self, reactor, clock):
 
-        self.clock = ThreadedMemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
         self.url = "/_matrix/client/r0/admin/register"
 
         self.registration_handler = Mock()
@@ -50,17 +44,14 @@ class UserRegisterTestCase(unittest.TestCase):
 
         self.secrets = Mock()
 
-        self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
+        self.hs = self.setup_test_homeserver()
 
         self.hs.config.registration_shared_secret = u"shared"
 
         self.hs.get_media_repository = Mock()
         self.hs.get_deactivate_account_handler = Mock()
 
-        self.resource = JsonResource(self.hs)
-        register_servlets(self.hs, self.resource)
+        return self.hs
 
     def test_disabled(self):
         """
@@ -69,8 +60,8 @@ class UserRegisterTestCase(unittest.TestCase):
         """
         self.hs.config.registration_shared_secret = None
 
-        request, channel = make_request("POST", self.url, b'{}')
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, b'{}')
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
@@ -87,8 +78,8 @@ class UserRegisterTestCase(unittest.TestCase):
 
         self.hs.get_secrets = Mock(return_value=secrets)
 
-        request, channel = make_request("GET", self.url)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
 
         self.assertEqual(channel.json_body, {"nonce": "abcd"})
 
@@ -97,25 +88,25 @@ class UserRegisterTestCase(unittest.TestCase):
         Calling GET on the endpoint will return a randomised nonce, which will
         only last for SALT_TIMEOUT (60s).
         """
-        request, channel = make_request("GET", self.url)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
         nonce = channel.json_body["nonce"]
 
         # 59 seconds
-        self.clock.advance(59)
+        self.reactor.advance(59)
 
         body = json.dumps({"nonce": nonce})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('username must be specified', channel.json_body["error"])
 
         # 61 seconds
-        self.clock.advance(2)
+        self.reactor.advance(2)
 
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -124,8 +115,8 @@ class UserRegisterTestCase(unittest.TestCase):
         """
         Only the provided nonce can be used, as it's checked in the MAC.
         """
-        request, channel = make_request("GET", self.url)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -141,8 +132,8 @@ class UserRegisterTestCase(unittest.TestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -152,12 +143,14 @@ class UserRegisterTestCase(unittest.TestCase):
         When the correct nonce is provided, and the right key is provided, the
         user is registered.
         """
-        request, channel = make_request("GET", self.url)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
-        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac.update(
+            nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support"
+        )
         want_mac = want_mac.hexdigest()
 
         body = json.dumps(
@@ -166,11 +159,12 @@ class UserRegisterTestCase(unittest.TestCase):
                 "username": "bob",
                 "password": "abc123",
                 "admin": True,
+                "user_type": UserTypes.SUPPORT,
                 "mac": want_mac,
             }
         )
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -179,12 +173,14 @@ class UserRegisterTestCase(unittest.TestCase):
         """
         A valid unrecognised nonce.
         """
-        request, channel = make_request("GET", self.url)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
-        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac.update(
+            nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin"
+        )
         want_mac = want_mac.hexdigest()
 
         body = json.dumps(
@@ -196,15 +192,15 @@ class UserRegisterTestCase(unittest.TestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -212,13 +208,13 @@ class UserRegisterTestCase(unittest.TestCase):
     def test_missing_parts(self):
         """
         Synapse will complain if you don't give nonce, username, password, and
-        mac.  Admin is optional.  Additional checks are done for length and
-        type.
+        mac.  Admin and user_types are optional.  Additional checks are done for length
+        and type.
         """
 
         def nonce():
-            request, channel = make_request("GET", self.url)
-            render(request, self.resource, self.clock)
+            request, channel = self.make_request("GET", self.url)
+            self.render(request)
             return channel.json_body["nonce"]
 
         #
@@ -227,8 +223,8 @@ class UserRegisterTestCase(unittest.TestCase):
 
         # Must be present
         body = json.dumps({})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('nonce must be specified', channel.json_body["error"])
@@ -239,52 +235,52 @@ class UserRegisterTestCase(unittest.TestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce()})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('username must be specified', channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": 1234})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid username', channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid username', channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid username', channel.json_body["error"])
 
         #
-        # Username checks
+        # Password checks
         #
 
         # Must be present
         body = json.dumps({"nonce": nonce(), "username": "a"})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('password must be specified', channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid password', channel.json_body["error"])
@@ -293,16 +289,33 @@ class UserRegisterTestCase(unittest.TestCase):
         body = json.dumps(
             {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
         )
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid password', channel.json_body["error"])
 
         # Super long
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
-        request, channel = make_request("POST", self.url, body.encode('utf8'))
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual('Invalid password', channel.json_body["error"])
+
+        #
+        # user_type check
+        #
+
+        # Invalid user_type
+        body = json.dumps({
+            "nonce": nonce(),
+            "username": "a",
+            "password": "1234",
+            "user_type": "invalid"}
+        )
+        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid user type', channel.json_body["error"])
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 956f7fc4c4..483bebc832 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -16,64 +16,49 @@
 """ Tests REST events for /events paths."""
 
 from mock import Mock, NonCallableMock
-from six import PY3
 
-from twisted.internet import defer
+from synapse.rest.client.v1 import admin, events, login, room
 
-from ....utils import MockHttpResource, setup_test_homeserver
-from .utils import RestTestCase
+from tests import unittest
 
-PATH_PREFIX = "/_matrix/client/api/v1"
 
-
-class EventStreamPermissionsTestCase(RestTestCase):
+class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
     """ Tests event streaming (GET /events). """
 
-    if PY3:
-        skip = "Skip on Py3 until ported to use not V1 only register."
+    servlets = [
+        events.register_servlets,
+        room.register_servlets,
+        admin.register_servlets,
+        login.register_servlets,
+    ]
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        import synapse.rest.client.v1.events
-        import synapse.rest.client.v1_only.register
-        import synapse.rest.client.v1.room
+    def make_homeserver(self, reactor, clock):
 
-        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
+        config = self.default_config()
+        config.enable_registration_captcha = False
+        config.enable_registration = True
+        config.auto_join_rooms = []
 
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
-            http_client=None,
-            federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+        hs = self.setup_test_homeserver(
+            config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
         )
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
-        hs.config.enable_registration_captcha = False
-        hs.config.enable_registration = True
-        hs.config.auto_join_rooms = []
 
         hs.get_handlers().federation_handler = Mock()
 
-        synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
-        synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
-        synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+        return hs
+
+    def prepare(self, hs, reactor, clock):
 
         # register an account
-        self.user_id = "sid1"
-        response = yield self.register(self.user_id)
-        self.token = response["access_token"]
-        self.user_id = response["user_id"]
+        self.user_id = self.register_user("sid1", "pass")
+        self.token = self.login(self.user_id, "pass")
 
         # register a 2nd account
-        self.other_user = "other1"
-        response = yield self.register(self.other_user)
-        self.other_token = response["access_token"]
-        self.other_user = response["user_id"]
+        self.other_user = self.register_user("other2", "pass")
+        self.other_token = self.login(self.other_user, "pass")
 
-    def tearDown(self):
-        pass
-
-    @defer.inlineCallbacks
     def test_stream_basic_permissions(self):
         # invalid token, expect 401
         # note: this is in violation of the original v1 spec, which expected
@@ -81,34 +66,37 @@ class EventStreamPermissionsTestCase(RestTestCase):
         # implementation is now part of the r0 implementation, the newer
         # behaviour is used instead to be consistent with the r0 spec.
         # see issue #2602
-        (code, response) = yield self.mock_resource.trigger_get(
-            "/events?access_token=%s" % ("invalid" + self.token,)
+        request, channel = self.make_request(
+            "GET", "/events?access_token=%s" % ("invalid" + self.token,)
         )
-        self.assertEquals(401, code, msg=str(response))
+        self.render(request)
+        self.assertEquals(channel.code, 401, msg=channel.result)
 
         # valid token, expect content
-        (code, response) = yield self.mock_resource.trigger_get(
-            "/events?access_token=%s&timeout=0" % (self.token,)
+        request, channel = self.make_request(
+            "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
-        self.assertEquals(200, code, msg=str(response))
-        self.assertTrue("chunk" in response)
-        self.assertTrue("start" in response)
-        self.assertTrue("end" in response)
+        self.render(request)
+        self.assertEquals(channel.code, 200, msg=channel.result)
+        self.assertTrue("chunk" in channel.json_body)
+        self.assertTrue("start" in channel.json_body)
+        self.assertTrue("end" in channel.json_body)
 
-    @defer.inlineCallbacks
     def test_stream_room_permissions(self):
-        room_id = yield self.create_room_as(self.other_user, tok=self.other_token)
-        yield self.send(room_id, tok=self.other_token)
+        room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
+        self.helper.send(room_id, tok=self.other_token)
 
         # invited to room (expect no content for room)
-        yield self.invite(
+        self.helper.invite(
             room_id, src=self.other_user, targ=self.user_id, tok=self.other_token
         )
 
-        (code, response) = yield self.mock_resource.trigger_get(
-            "/events?access_token=%s&timeout=0" % (self.token,)
+        # valid token, expect content
+        request, channel = self.make_request(
+            "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
-        self.assertEquals(200, code, msg=str(response))
+        self.render(request)
+        self.assertEquals(channel.code, 200, msg=channel.result)
 
         # We may get a presence event for ourselves down
         self.assertEquals(
@@ -116,7 +104,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
             len(
                 [
                     c
-                    for c in response["chunk"]
+                    for c in channel.json_body["chunk"]
                     if not (
                         c.get("type") == "m.presence"
                         and c["content"].get("user_id") == self.user_id
@@ -126,7 +114,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
         )
 
         # joined room (expect all content for room)
-        yield self.join(room=room_id, user=self.user_id, tok=self.token)
+        self.helper.join(room=room_id, user=self.user_id, tok=self.token)
 
         # left to room (expect no content for room)
 
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
deleted file mode 100644
index 6b7ff813d5..0000000000
--- a/tests/rest/client/v1/test_register.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-
-from mock import Mock
-from six import PY3
-
-from twisted.test.proto_helpers import MemoryReactorClock
-
-from synapse.http.server import JsonResource
-from synapse.rest.client.v1_only.register import register_servlets
-from synapse.util import Clock
-
-from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
-
-
-class CreateUserServletTestCase(unittest.TestCase):
-    """
-    Tests for CreateUserRestServlet.
-    """
-
-    if PY3:
-        skip = "Not ported to Python 3."
-
-    def setUp(self):
-        self.registration_handler = Mock()
-
-        self.appservice = Mock(sender="@as:test")
-        self.datastore = Mock(
-            get_app_service_by_token=Mock(return_value=self.appservice)
-        )
-
-        handlers = Mock(registration_handler=self.registration_handler)
-        self.clock = MemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
-
-        self.hs = self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
-        self.hs.get_datastore = Mock(return_value=self.datastore)
-        self.hs.get_handlers = Mock(return_value=handlers)
-
-    def test_POST_createuser_with_valid_user(self):
-
-        res = JsonResource(self.hs)
-        register_servlets(self.hs, res)
-
-        request_data = json.dumps(
-            {
-                "localpart": "someone",
-                "displayname": "someone interesting",
-                "duration_seconds": 200,
-            }
-        )
-
-        url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
-
-        user_id = "@someone:interesting"
-        token = "my token"
-
-        self.registration_handler.get_or_create_user = Mock(
-            return_value=(user_id, token)
-        )
-
-        request, channel = make_request(b"POST", url, request_data)
-        render(request, res, self.clock)
-
-        self.assertEquals(channel.result["code"], b"200")
-
-        det_data = {
-            "user_id": user_id,
-            "access_token": token,
-            "home_server": self.hs.hostname,
-        }
-        self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 530dc8ba6d..9c401bf300 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -169,7 +169,7 @@ class RestHelper(object):
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            "POST", path, json.dumps(content).encode('utf8')
+            self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
         )
         render(request, self.resource, self.hs.get_reactor())
 
@@ -217,7 +217,9 @@ class RestHelper(object):
 
         data = {"membership": membership}
 
-        request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
+        request, channel = make_request(
+            self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+        )
 
         render(request, self.resource, self.hs.get_reactor())
 
@@ -228,18 +230,6 @@ class RestHelper(object):
 
         self.auth_user_id = temp_id
 
-    @defer.inlineCallbacks
-    def register(self, user_id):
-        (code, response) = yield self.mock_resource.trigger(
-            "POST",
-            "/_matrix/client/r0/register",
-            json.dumps(
-                {"user": user_id, "password": "test", "type": "m.login.password"}
-            ),
-        )
-        self.assertEquals(200, code)
-        defer.returnValue(response)
-
     def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -251,7 +241,9 @@ class RestHelper(object):
         if tok:
             path = path + "?access_token=%s" % tok
 
-        request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+        request, channel = make_request(
+            self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+        )
         render(request, self.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        auth.register_servlets,
+        admin.register_servlets,
+        register.register_servlets,
+    ]
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+
+        config = self.default_config()
+
+        config.enable_registration_captcha = True
+        config.recaptcha_public_key = "brokencake"
+        config.registrations_require_3pid = []
+
+        hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        auth_handler = hs.get_auth_handler()
+
+        self.recaptcha_attempts = []
+
+        def _recaptcha(authdict, clientip):
+            self.recaptcha_attempts.append((authdict, clientip))
+            return succeed(True)
+
+        auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+    @unittest.INFO
+    def test_fallback_captcha(self):
+
+        request, channel = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.render(request)
+
+        # Returns a 401 as per the spec
+        self.assertEqual(request.code, 401)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Assert our configured public key is being given
+        self.assertEqual(
+            channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+        )
+
+        request, channel = self.make_request(
+            "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        request, channel = self.make_request(
+            "POST",
+            "auth/m.login.recaptcha/fallback/web?session="
+            + session
+            + "&g-recaptcha-response=a",
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        # The recaptcha handler is called with the response given
+        self.assertEqual(len(self.recaptcha_attempts), 1)
+        self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+        # Now we have fufilled the recaptcha fallback step, we can then send a
+        # request to the register API with the session in the authdict.
+        request, channel = self.make_request(
+            "POST", "register", {"auth": {"session": session}}
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
new file mode 100644
index 0000000000..d3d43970fb
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v2_alpha import capabilities
+
+from tests import unittest
+
+
+class CapabilitiesTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        capabilities.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.url = b"/_matrix/client/r0/capabilities"
+        hs = self.setup_test_homeserver()
+        self.store = hs.get_datastore()
+        return hs
+
+    def test_check_auth_required(self):
+        request, channel = self.make_request("GET", self.url)
+        self.render(request)
+
+        self.assertEqual(channel.code, 401)
+
+    def test_get_room_version_capabilities(self):
+        self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        self.render(request)
+        capabilities = channel.json_body['capabilities']
+
+        self.assertEqual(channel.code, 200)
+        for room_version in capabilities['m.room_versions']['available'].keys():
+            self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+        self.assertEqual(
+            DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default']
+        )
+
+    def test_get_change_password_capabilities(self):
+        localpart = "user"
+        password = "pass"
+        user = self.register_user(localpart, password)
+        access_token = self.login(user, password)
+
+        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        self.render(request)
+        capabilities = channel.json_body['capabilities']
+
+        self.assertEqual(channel.code, 200)
+
+        # Test case where password is handled outside of Synapse
+        self.assertTrue(capabilities['m.change_password']['enabled'])
+        self.get_success(self.store.user_set_password_hash(user, None))
+        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        self.render(request)
+        capabilities = channel.json_body['capabilities']
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities['m.change_password']['enabled'])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 6a886ee3b8..f42a8efbf4 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,84 +13,47 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import synapse.types
 from synapse.api.errors import Codes
-from synapse.http.server import JsonResource
 from synapse.rest.client.v2_alpha import filter
-from synapse.types import UserID
-from synapse.util import Clock
 
 from tests import unittest
-from tests.server import (
-    ThreadedMemoryReactorClock as MemoryReactorClock,
-    make_request,
-    render,
-    setup_test_homeserver,
-)
 
 PATH_PREFIX = "/_matrix/client/v2_alpha"
 
 
-class FilterTestCase(unittest.TestCase):
+class FilterTestCase(unittest.HomeserverTestCase):
 
-    USER_ID = "@apple:test"
+    user_id = "@apple:test"
+    hijack_auth = True
     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
     EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
-    TO_REGISTER = [filter]
+    servlets = [filter.register_servlets]
 
-    def setUp(self):
-        self.clock = MemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
-
-        self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
-
-        self.auth = self.hs.get_auth()
-
-        def get_user_by_access_token(token=None, allow_guest=False):
-            return {
-                "user": UserID.from_string(self.USER_ID),
-                "token_id": 1,
-                "is_guest": False,
-            }
-
-        def get_user_by_req(request, allow_guest=False, rights="access"):
-            return synapse.types.create_requester(
-                UserID.from_string(self.USER_ID), 1, False, None
-            )
-
-        self.auth.get_user_by_access_token = get_user_by_access_token
-        self.auth.get_user_by_req = get_user_by_req
-
-        self.store = self.hs.get_datastore()
-        self.filtering = self.hs.get_filtering()
-        self.resource = JsonResource(self.hs)
-
-        for r in self.TO_REGISTER:
-            r.register_servlets(self.hs, self.resource)
+    def prepare(self, reactor, clock, hs):
+        self.filtering = hs.get_filtering()
+        self.store = hs.get_datastore()
 
     def test_add_filter(self):
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+            "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
-        self.clock.advance(0)
+        self.pump()
         self.assertEquals(filter.result, self.EXAMPLE_FILTER)
 
     def test_add_filter_for_other_user(self):
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"403")
         self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
     def test_add_filter_non_local_user(self):
         _is_mine = self.hs.is_mine
         self.hs.is_mine = lambda target_user: False
-        request, channel = make_request(
+        request, channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+            "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.hs.is_mine = _is_mine
         self.assertEqual(channel.result["code"], b"403")
@@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
         filter_id = self.filtering.add_user_filter(
             user_localpart="apple", user_filter=self.EXAMPLE_FILTER
         )
-        self.clock.advance(1)
+        self.reactor.advance(1)
         filter_id = filter_id.result
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
         self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
     def test_get_filter_invalid_id(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
 
     # No ID also returns an invalid_id error
     def test_get_filter_no_id(self):
-        request, channel = make_request(
-            "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..906b348d3e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,83 +1,51 @@
 import json
 
-from mock import Mock
-
-from twisted.python import failure
-from twisted.test.proto_helpers import MemoryReactorClock
-
-from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.http.server import JsonResource
+from synapse.api.constants import LoginType
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v2_alpha.register import register_servlets
-from synapse.util import Clock
 
 from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
 
 
-class RegisterRestServletTestCase(unittest.TestCase):
-    def setUp(self):
+class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
-        self.clock = MemoryReactorClock()
-        self.hs_clock = Clock(self.clock)
-        self.url = b"/_matrix/client/r0/register"
+    servlets = [register_servlets]
 
-        self.appservice = None
-        self.auth = Mock(
-            get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
-        )
+    def make_homeserver(self, reactor, clock):
 
-        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
-        self.auth_handler = Mock(
-            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
-            get_session_data=Mock(return_value=None),
-        )
-        self.registration_handler = Mock()
-        self.identity_handler = Mock()
-        self.login_handler = Mock()
-        self.device_handler = Mock()
-        self.device_handler.check_device_registered = Mock(return_value="FAKE")
-
-        self.datastore = Mock(return_value=Mock())
-        self.datastore.get_current_state_deltas = Mock(return_value=[])
-
-        # do the dance to hook it up to the hs global
-        self.handlers = Mock(
-            registration_handler=self.registration_handler,
-            identity_handler=self.identity_handler,
-            login_handler=self.login_handler,
-        )
-        self.hs = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
-        )
-        self.hs.get_auth = Mock(return_value=self.auth)
-        self.hs.get_handlers = Mock(return_value=self.handlers)
-        self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
-        self.hs.get_device_handler = Mock(return_value=self.device_handler)
-        self.hs.get_datastore = Mock(return_value=self.datastore)
+        self.url = b"/_matrix/client/r0/register"
+
+        self.hs = self.setup_test_homeserver()
         self.hs.config.enable_registration = True
         self.hs.config.registrations_require_3pid = []
         self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
 
-        self.resource = JsonResource(self.hs)
-        register_servlets(self.hs, self.resource)
+        return self.hs
 
     def test_POST_appservice_registration_valid(self):
-        user_id = "@kermit:muppet"
-        token = "kermits_access_token"
-        self.appservice = {"id": "1234"}
-        self.registration_handler.appservice_register = Mock(return_value=user_id)
-        self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
-        request_data = json.dumps({"username": "kermit"})
+        user_id = "@as_user_kermit:test"
+        as_token = "i_am_an_app_service"
+
+        appservice = ApplicationService(
+            as_token, self.hs.config.hostname,
+            id="1234",
+            namespaces={
+                "users": [{"regex": r"@as_user.*", "exclusive": True}],
+            },
+        )
+
+        self.hs.get_datastore().services_cache.append(appservice)
+        request_data = json.dumps({"username": "as_user_kermit"})
 
-        request, channel = make_request(
+        request, channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
         det_data = {
             "user_id": user_id,
-            "access_token": token,
             "home_server": self.hs.hostname,
         }
         self.assertDictContainsSubset(det_data, channel.json_body)
@@ -85,81 +53,69 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_appservice_registration_invalid(self):
         self.appservice = None  # no application service exists
         request_data = json.dumps({"username": "kermit"})
-        request, channel = make_request(
+        request, channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        render(request, self.resource, self.clock)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
     def test_POST_bad_password(self):
         request_data = json.dumps({"username": "kermit", "password": 666})
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
 
     def test_POST_bad_username(self):
         request_data = json.dumps({"username": 777, "password": "monkey"})
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid username")
 
     def test_POST_user_valid(self):
-        user_id = "@kermit:muppet"
-        token = "kermits_access_token"
+        user_id = "@kermit:test"
         device_id = "frogfone"
-        request_data = json.dumps(
-            {"username": "kermit", "password": "monkey", "device_id": device_id}
-        )
-        self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
-        self.registration_handler.register = Mock(return_value=(user_id, None))
-        self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
-        self.device_handler.check_device_registered = Mock(return_value=device_id)
-
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+            "device_id": device_id,
+            "auth": {"type": LoginType.DUMMY},
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         det_data = {
             "user_id": user_id,
-            "access_token": token,
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
         self.assertEquals(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
-        self.auth_handler.get_login_tuple_for_user_id(
-            user_id, device_id=device_id, initial_device_display_name=None
-        )
 
     def test_POST_disabled_registration(self):
         self.hs.config.enable_registration = False
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
-        self.registration_handler.check_username = Mock(return_value=True)
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
-        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
 
-        request, channel = make_request(b"POST", self.url, request_data)
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url, request_data)
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
 
     def test_POST_guest_registration(self):
-        user_id = "a@b"
         self.hs.config.macaroon_secret_key = "test"
         self.hs.config.allow_guest_access = True
-        self.registration_handler.register = Mock(return_value=(user_id, None))
 
-        request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
 
         det_data = {
-            "user_id": user_id,
             "home_server": self.hs.hostname,
             "device_id": "guest_device",
         }
@@ -169,8 +125,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_disabled_guest_registration(self):
         self.hs.config.allow_guest_access = False
 
-        request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        render(request, self.resource, self.clock)
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 4c30c5f258..99b716f00a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,9 +15,11 @@
 
 from mock import Mock
 
+from synapse.rest.client.v1 import admin, login, room
 from synapse.rest.client.v2_alpha import sync
 
 from tests import unittest
+from tests.server import TimedOutException
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
@@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase):
                 ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
             ).issubset(set(channel.json_body.keys()))
         )
+
+
+class SyncTypingTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
+    user_id = True
+    hijack_auth = False
+
+    def test_sync_backwards_typing(self):
+        """
+        If the typing serial goes backwards and the typing handler is then reset
+        (such as when the master restarts and sets the typing serial to 0), we
+        do not incorrectly return typing information that had a serial greater
+        than the now-reset serial.
+        """
+        typing_url = "/rooms/%s/typing/%s?access_token=%s"
+        sync_url = "/sync?timeout=3000000&access_token=%s&since=%s"
+
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the user who sends the message
+        other_user_id = self.register_user("otheruser", "pass")
+        other_access_token = self.login("otheruser", "pass")
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # Invite the other person
+        self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+        # The other user joins
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # The other user sends some messages
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.helper.send(room, body="There!", tok=other_access_token)
+
+        # Start typing.
+        request, channel = self.make_request(
+            "PUT",
+            typing_url % (room, other_user_id, other_access_token),
+            b'{"typing": true, "timeout": 30000}',
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        request, channel = self.make_request(
+            "GET", "/sync?access_token=%s" % (access_token,)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        next_batch = channel.json_body["next_batch"]
+
+        # Stop typing.
+        request, channel = self.make_request(
+            "PUT",
+            typing_url % (room, other_user_id, other_access_token),
+            b'{"typing": false}',
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        # Start typing.
+        request, channel = self.make_request(
+            "PUT",
+            typing_url % (room, other_user_id, other_access_token),
+            b'{"typing": true, "timeout": 30000}',
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        # Should return immediately
+        request, channel = self.make_request(
+            "GET", sync_url % (access_token, next_batch)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        next_batch = channel.json_body["next_batch"]
+
+        # Reset typing serial back to 0, as if the master had.
+        typing = self.hs.get_typing_handler()
+        typing._latest_room_serial = 0
+
+        # Since it checks the state token, we need some state to update to
+        # invalidate the stream token.
+        self.helper.send(room, body="There!", tok=other_access_token)
+
+        request, channel = self.make_request(
+            "GET", sync_url % (access_token, next_batch)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        next_batch = channel.json_body["next_batch"]
+
+        # This should time out! But it does not, because our stream token is
+        # ahead, and therefore it's saying the typing (that we've actually
+        # already seen) is new, since it's got a token above our new, now-reset
+        # stream token.
+        request, channel = self.make_request(
+            "GET", sync_url % (access_token, next_batch)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        next_batch = channel.json_body["next_batch"]
+
+        # Clear the typing information, so that it doesn't think everything is
+        # in the future.
+        typing._reset()
+
+        # Now it SHOULD fail as it never completes!
+        request, channel = self.make_request(
+            "GET", sync_url % (access_token, next_batch)
+        )
+        self.assertRaises(TimedOutException, self.render, request)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a86901c2d8..ad5e9a612f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -17,15 +17,21 @@
 import os
 import shutil
 import tempfile
+from binascii import unhexlify
 
 from mock import Mock
+from six.moves.urllib import parse
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
+from synapse.config.repository import MediaStorageProviderConfig
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.module_loader import load_module
 
 from tests import unittest
 
@@ -83,3 +89,143 @@ class MediaStorageTests(unittest.TestCase):
             body = f.read()
 
         self.assertEqual(test_body, body)
+
+
+class MediaRepoTests(unittest.HomeserverTestCase):
+
+    hijack_auth = True
+    user_id = "@test:user"
+
+    def make_homeserver(self, reactor, clock):
+
+        self.fetches = []
+
+        def get_file(destination, path, output_stream, args=None, max_size=None):
+            """
+            Returns tuple[int,dict,str,int] of file length, response headers,
+            absolute URI, and response code.
+            """
+
+            def write_to(r):
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            d = Deferred()
+            d.addCallback(write_to)
+            self.fetches.append((d, destination, path, args))
+            return make_deferred_yieldable(d)
+
+        client = Mock()
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        os.mkdir(self.storage_path)
+
+        config = self.default_config()
+        config.media_store_path = self.storage_path
+        config.thumbnail_requirements = {}
+        config.max_image_pixels = 2000000
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        loaded = list(load_module(provider_config)) + [
+            MediaStorageProviderConfig(False, False, False)
+        ]
+
+        config.media_storage_providers = [loaded]
+
+        hs = self.setup_test_homeserver(config=config, http_client=client)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b'download']
+
+        # smol png
+        self.end_content = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+    def _req(self, content_disposition):
+
+        request, channel = self.make_request(
+            "GET", "example.com/12345", shorthand=False
+        )
+        request.render(self.download_resource)
+        self.pump()
+
+        # We've made one fetch, to example.com, using the media URL, and asking
+        # the other server not to do a remote fetch
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+        )
+        self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.end_content))],
+            b"Content-Type": [b'image/png'],
+        }
+        if content_disposition:
+            headers[b"Content-Disposition"] = [content_disposition]
+
+        self.fetches[0][0].callback(
+            (self.end_content, (len(self.end_content), headers))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        return channel
+
+    def test_disposition_filename_ascii(self):
+        """
+        If the filename is filename=<ascii> then Synapse will decode it as an
+        ASCII string, and use filename= in the response.
+        """
+        channel = self._req(b"inline; filename=out.png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+        )
+
+    def test_disposition_filenamestar_utf8escaped(self):
+        """
+        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+        correctly decode it as the UTF-8 string, and use filename* in the
+        response.
+        """
+        filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+        channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline; filename*=utf-8''" + filename + b".png"],
+        )
+
+    def test_disposition_none(self):
+        """
+        If there is no filename, one isn't passed on in the Content-Disposition
+        of the request.
+        """
+        channel = self._req(None)
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
new file mode 100644
index 0000000000..650ce95a6f
--- /dev/null
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -0,0 +1,470 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import attr
+from netaddr import IPSet
+
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.error import DNSLookupError
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web._newclient import ResponseDone
+
+from synapse.config.repository import MediaStorageProviderConfig
+from synapse.util.module_loader import load_module
+
+from tests import unittest
+from tests.server import FakeTransport
+
+
+@attr.s
+class FakeResponse(object):
+    version = attr.ib()
+    code = attr.ib()
+    phrase = attr.ib()
+    headers = attr.ib()
+    body = attr.ib()
+    absoluteURI = attr.ib()
+
+    @property
+    def request(self):
+        @attr.s
+        class FakeTransport(object):
+            absoluteURI = self.absoluteURI
+
+        return FakeTransport()
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
+
+
+class URLPreviewTests(unittest.HomeserverTestCase):
+
+    hijack_auth = True
+    user_id = "@test:user"
+    end_content = (
+        b'<html><head>'
+        b'<meta property="og:title" content="~matrix~" />'
+        b'<meta property="og:description" content="hi" />'
+        b'</head></html>'
+    )
+
+    def make_homeserver(self, reactor, clock):
+
+        self.storage_path = self.mktemp()
+        os.mkdir(self.storage_path)
+
+        config = self.default_config()
+        config.url_preview_enabled = True
+        config.max_spider_size = 9999999
+        config.url_preview_ip_range_blacklist = IPSet(
+            (
+                "192.168.1.1",
+                "1.0.0.0/8",
+                "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+                "2001:800::/21",
+            )
+        )
+        config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
+        config.url_preview_url_blacklist = []
+        config.media_store_path = self.storage_path
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        loaded = list(load_module(provider_config)) + [
+            MediaStorageProviderConfig(False, False, False)
+        ]
+
+        config.media_storage_providers = [loaded]
+
+        hs = self.setup_test_homeserver(config=config)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+
+        self.media_repo = hs.get_media_repository_resource()
+        self.preview_url = self.media_repo.children[b'preview_url']
+
+        self.lookups = {}
+
+        class Resolver(object):
+            def resolveHostName(
+                _self,
+                resolutionReceiver,
+                hostName,
+                portNumber=0,
+                addressTypes=None,
+                transportSemantics='TCP',
+            ):
+
+                resolution = HostResolution(hostName)
+                resolutionReceiver.resolutionBegan(resolution)
+                if hostName not in self.lookups:
+                    raise DNSLookupError("OH NO")
+
+                for i in self.lookups[hostName]:
+                    resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
+                resolutionReceiver.resolutionComplete()
+                return resolutionReceiver
+
+        self.reactor.nameResolver = Resolver()
+
+    def test_cache_returns_correct_type(self):
+        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+            % (len(self.end_content),)
+            + self.end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+        # Check the cache returns the correct response
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # Check the cache response has the same content
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+        # Clear the in-memory cache
+        self.assertIn("http://matrix.org", self.preview_url._cache)
+        self.preview_url._cache.pop("http://matrix.org")
+        self.assertNotIn("http://matrix.org", self.preview_url._cache)
+
+        # Check the database cache returns the correct response
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # Check the cache response has the same content
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+    def test_non_ascii_preview_httpequiv(self):
+        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+        end_content = (
+            b'<html><head>'
+            b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+    def test_non_ascii_preview_content_type(self):
+        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+        end_content = (
+            b'<html><head>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+    def test_ipaddr(self):
+        """
+        IP addresses can be previewed directly.
+        """
+        self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+            % (len(self.end_content),)
+            + self.end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+    def test_blacklisted_ip_specific(self):
+        """
+        Blacklisted IP addresses, found via DNS, are not spidered.
+        """
+        self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # No requests made.
+        self.assertEqual(len(self.reactor.tcpClients), 0)
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ip_range(self):
+        """
+        Blacklisted IP ranges, IPs found over DNS, are not spidered.
+        """
+        self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ip_specific_direct(self):
+        """
+        Blacklisted IP addresses, accessed directly, are not spidered.
+        """
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://192.168.1.1", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # No requests made.
+        self.assertEqual(len(self.reactor.tcpClients), 0)
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ip_range_direct(self):
+        """
+        Blacklisted IP ranges, accessed directly, are not spidered.
+        """
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://1.1.1.2", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ip_range_whitelisted_ip(self):
+        """
+        Blacklisted but then subsequently whitelisted IP addresses can be
+        spidered.
+        """
+        self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+
+        client.dataReceived(
+            b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+            % (len(self.end_content),)
+            + self.end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+    def test_blacklisted_ip_with_external_ip(self):
+        """
+        If a hostname resolves a blacklisted IP, even if there's a
+        non-blacklisted one, it will be rejected.
+        """
+        # Hardcode the URL resolving to the IP we want.
+        self.lookups[u"example.com"] = [
+            (IPv4Address, "1.1.1.2"),
+            (IPv4Address, "8.8.8.8"),
+        ]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ipv6_specific(self):
+        """
+        Blacklisted IP addresses, found via DNS, are not spidered.
+        """
+        self.lookups["example.com"] = [
+            (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+        ]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # No requests made.
+        self.assertEqual(len(self.reactor.tcpClients), 0)
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
+
+    def test_blacklisted_ipv6_range(self):
+        """
+        Blacklisted IP ranges, IPs found over DNS, are not spidered.
+        """
+        self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        self.assertEqual(channel.code, 403)
+        self.assertEqual(
+            channel.json_body,
+            {
+                'errcode': 'M_UNKNOWN',
+                'error': 'IP address blocked by IP blacklist entry',
+            },
+        )
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
new file mode 100644
index 0000000000..8d8f03e005
--- /dev/null
+++ b/tests/rest/test_well_known.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.well_known import WellKnownResource
+
+from tests import unittest
+
+
+class WellKnownTests(unittest.HomeserverTestCase):
+    def setUp(self):
+        super(WellKnownTests, self).setUp()
+
+        # replace the JsonResource with a WellKnownResource
+        self.resource = WellKnownResource(self.hs)
+
+    def test_well_known(self):
+        self.hs.config.public_baseurl = "https://tesths"
+        self.hs.config.default_identity_server = "https://testis"
+
+        request, channel = self.make_request(
+            "GET",
+            "/.well-known/matrix/client",
+            shorthand=False,
+        )
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(
+            channel.json_body, {
+                "m.homeserver": {"base_url": "https://tesths"},
+                "m.identity_server": {"base_url": "https://testis"},
+            }
+        )
+
+    def test_well_known_no_public_baseurl(self):
+        self.hs.config.public_baseurl = None
+
+        request, channel = self.make_request(
+            "GET",
+            "/.well-known/matrix/client",
+            shorthand=False,
+        )
+        self.render(request)
+
+        self.assertEqual(request.code, 404)