summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py2
-rw-r--r--tests/api/test_auth.py12
-rw-r--r--tests/config/test_server.py8
-rw-r--r--tests/config/test_tls.py4
-rw-r--r--tests/crypto/test_event_signing.py46
-rw-r--r--tests/crypto/test_keyring.py334
-rw-r--r--tests/events/test_utils.py94
-rw-r--r--tests/federation/test_complexity.py90
-rw-r--r--tests/federation/test_federation_sender.py48
-rw-r--r--tests/handlers/test_auth.py10
-rw-r--r--tests/handlers/test_directory.py8
-rw-r--r--tests/handlers/test_e2e_room_keys.py18
-rw-r--r--tests/handlers/test_register.py31
-rw-r--r--tests/handlers/test_stats.py304
-rw-r--r--tests/handlers/test_typing.py16
-rw-r--r--tests/handlers/test_user_directory.py8
-rw-r--r--tests/http/__init__.py124
-rw-r--r--tests/http/ca.crt19
-rw-r--r--tests/http/ca.key27
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py358
-rw-r--r--tests/http/federation/test_srv_resolver.py2
-rw-r--r--tests/http/server.key27
-rw-r--r--tests/http/server.pem81
-rw-r--r--tests/http/test_endpoint.py12
-rw-r--r--tests/http/test_fedclient.py19
-rw-r--r--tests/push/test_email.py99
-rw-r--r--tests/push/test_http.py6
-rw-r--r--tests/rest/admin/test_admin.py101
-rw-r--r--tests/rest/client/test_consent.py16
-rw-r--r--tests/rest/client/test_identity.py2
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_profile.py63
-rw-r--r--tests/rest/client/v1/test_rooms.py48
-rw-r--r--tests/rest/client/v1/utils.py23
-rw-r--r--tests/rest/client/v2_alpha/test_account.py281
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py19
-rw-r--r--tests/rest/client/v2_alpha/test_register.py112
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py10
-rw-r--r--tests/rest/media/v1/test_base.py14
-rw-r--r--tests/rest/media/v1/test_media_storage.py6
-rw-r--r--tests/rest/media/v1/test_url_preview.py56
-rw-r--r--tests/server.py16
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py4
-rw-r--r--tests/state/test_v2.py20
-rw-r--r--tests/storage/test_appservice.py4
-rw-r--r--tests/storage/test_cleanup_extrems.py165
-rw-r--r--tests/storage/test_client_ips.py20
-rw-r--r--tests/storage/test_devices.py69
-rw-r--r--tests/storage/test_end_to_end_keys.py8
-rw-r--r--tests/storage/test_event_federation.py14
-rw-r--r--tests/storage/test_event_metrics.py84
-rw-r--r--tests/storage/test_keys.py70
-rw-r--r--tests/storage/test_monthly_active_users.py22
-rw-r--r--tests/storage/test_redaction.py4
-rw-r--r--tests/storage/test_registration.py2
-rw-r--r--tests/storage/test_room.py4
-rw-r--r--tests/storage/test_state.py16
-rw-r--r--tests/test_preview.py136
-rw-r--r--tests/test_server.py14
-rw-r--r--tests/test_state.py6
-rw-r--r--tests/test_types.py4
-rw-r--r--tests/test_utils/logging_setup.py2
-rw-r--r--tests/test_visibility.py2
-rw-r--r--tests/unittest.py85
-rw-r--r--tests/util/caches/test_descriptors.py60
-rw-r--r--tests/util/caches/test_ttlcache.py46
-rw-r--r--tests/utils.py24
67 files changed, 2593 insertions, 945 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index d3181f9403..f7fc502f01 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
 # attempt to do the patch before we load any synapse code
 tests.patch_inline_callbacks.do_patch()
 
-util.DEFAULT_TIMEOUT_DURATION = 10
+util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index d0d36f96fa..d4e75b5b2e 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -172,7 +172,7 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(
-            requester.user.to_string(), masquerading_user_id.decode('utf8')
+            requester.user.to_string(), masquerading_user_id.decode("utf8")
         )
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@@ -264,7 +264,7 @@ class AuthTestCase(unittest.TestCase):
 
         # check the token works
         request = Mock(args={})
-        request.args[b"access_token"] = [token.encode('ascii')]
+        request.args[b"access_token"] = [token.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -277,7 +277,7 @@ class AuthTestCase(unittest.TestCase):
 
         # the token should *not* work now
         request = Mock(args={})
-        request.args[b"access_token"] = [guest_tok.encode('ascii')]
+        request.args[b"access_token"] = [guest_tok.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         with self.assertRaises(AuthError) as cm:
@@ -321,11 +321,11 @@ class AuthTestCase(unittest.TestCase):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
         self.store.get_monthly_active_count = lambda: defer.succeed(2)
-        threepid = {'medium': 'email', 'address': 'reserved@server.com'}
-        unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
+        threepid = {"medium": "email", "address": "reserved@server.com"}
+        unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.hs.config.mau_limits_reserved_threepids = [threepid]
 
-        yield self.store.register(user_id='user1', token="123", password_hash=None)
+        yield self.store.register(user_id="user1", token="123", password_hash=None)
         with self.assertRaises(ResourceLimitError):
             yield self.auth.check_auth_blocking()
 
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index de64965a60..1ca5ea54ca 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -20,10 +20,10 @@ from tests import unittest
 
 class ServerConfigTestCase(unittest.TestCase):
     def test_is_threepid_reserved(self):
-        user1 = {'medium': 'email', 'address': 'user1@example.com'}
-        user2 = {'medium': 'email', 'address': 'user2@example.com'}
-        user3 = {'medium': 'email', 'address': 'user3@example.com'}
-        user1_msisdn = {'medium': 'msisdn', 'address': '447700000000'}
+        user1 = {"medium": "email", "address": "user1@example.com"}
+        user2 = {"medium": "email", "address": "user2@example.com"}
+        user3 = {"medium": "email", "address": "user3@example.com"}
+        user1_msisdn = {"medium": "msisdn", "address": "447700000000"}
         config = [user1, user2]
 
         self.assertTrue(is_threepid_reserved(config, user1))
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index 40ca428778..a5d88d644a 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -32,7 +32,7 @@ class TLSConfigTests(TestCase):
         """
         config_dir = self.mktemp()
         os.mkdir(config_dir)
-        with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
+        with open(os.path.join(config_dir, "cert.pem"), "w") as f:
             f.write(
                 """-----BEGIN CERTIFICATE-----
 MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
@@ -65,7 +65,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         }
 
         t = TestConfig()
-        t.read_config(config)
+        t.read_config(config, config_dir_path="", data_dir_path="")
         t.read_certificate_from_disk(require_cert_and_key=False)
 
         warnings = self.flushWarnings()
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 71aa731439..126e176004 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -41,25 +41,25 @@ class EventSigningTestCase(unittest.TestCase):
 
     def test_sign_minimal(self):
         event_dict = {
-            'event_id': "$0:domain",
-            'origin': "domain",
-            'origin_server_ts': 1000000,
-            'signatures': {},
-            'type': "X",
-            'unsigned': {'age_ts': 1000000},
+            "event_id": "$0:domain",
+            "origin": "domain",
+            "origin_server_ts": 1000000,
+            "signatures": {},
+            "type": "X",
+            "unsigned": {"age_ts": 1000000},
         }
 
         add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
 
         event = FrozenEvent(event_dict)
 
-        self.assertTrue(hasattr(event, 'hashes'))
-        self.assertIn('sha256', event.hashes)
+        self.assertTrue(hasattr(event, "hashes"))
+        self.assertIn("sha256", event.hashes)
         self.assertEquals(
-            event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
+            event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
         )
 
-        self.assertTrue(hasattr(event, 'signatures'))
+        self.assertTrue(hasattr(event, "signatures"))
         self.assertIn(HOSTNAME, event.signatures)
         self.assertIn(KEY_NAME, event.signatures["domain"])
         self.assertEquals(
@@ -70,28 +70,28 @@ class EventSigningTestCase(unittest.TestCase):
 
     def test_sign_message(self):
         event_dict = {
-            'content': {'body': "Here is the message content"},
-            'event_id': "$0:domain",
-            'origin': "domain",
-            'origin_server_ts': 1000000,
-            'type': "m.room.message",
-            'room_id': "!r:domain",
-            'sender': "@u:domain",
-            'signatures': {},
-            'unsigned': {'age_ts': 1000000},
+            "content": {"body": "Here is the message content"},
+            "event_id": "$0:domain",
+            "origin": "domain",
+            "origin_server_ts": 1000000,
+            "type": "m.room.message",
+            "room_id": "!r:domain",
+            "sender": "@u:domain",
+            "signatures": {},
+            "unsigned": {"age_ts": 1000000},
         }
 
         add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
 
         event = FrozenEvent(event_dict)
 
-        self.assertTrue(hasattr(event, 'hashes'))
-        self.assertIn('sha256', event.hashes)
+        self.assertTrue(hasattr(event, "hashes"))
+        self.assertIn("sha256", event.hashes)
         self.assertEquals(
-            event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
+            event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
         )
 
-        self.assertTrue(hasattr(event, 'signatures'))
+        self.assertTrue(hasattr(event, "signatures"))
         self.assertIn(HOSTNAME, event.signatures)
         self.assertIn(KEY_NAME, event.signatures["domain"])
         self.assertEquals(
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3c79d4afe7..5a355f00cc 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,12 +19,18 @@ from mock import Mock
 import canonicaljson
 import signedjson.key
 import signedjson.sign
+from signedjson.key import encode_verify_key_base64, get_verify_key
 
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+    PerspectivesKeyFetcher,
+    ServerKeyFetcher,
+    StoreKeyFetcher,
+)
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
 
@@ -38,7 +44,7 @@ class MockPerspectiveServer(object):
 
     def get_verify_keys(self):
         vk = signedjson.key.get_verify_key(self.key)
-        return {"%s:%s" % (vk.alg, vk.version): vk}
+        return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)}
 
     def get_signed_key(self, server_name, verify_key):
         key_id = "%s:%s" % (verify_key.alg, verify_key.version)
@@ -46,25 +52,31 @@ class MockPerspectiveServer(object):
             "server_name": server_name,
             "old_verify_keys": {},
             "valid_until_ts": time.time() * 1000 + 3600,
-            "verify_keys": {
-                key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
-            },
+            "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}},
         }
-        return self.get_signed_response(res)
+        self.sign_response(res)
+        return res
 
-    def get_signed_response(self, res):
+    def sign_response(self, res):
         signedjson.sign.sign_json(res, self.server_name, self.key)
-        return res
 
 
 class KeyringTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.mock_perspective_server = MockPerspectiveServer()
         self.http_client = Mock()
-        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
-        keys = self.mock_perspective_server.get_verify_keys()
-        hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
-        return hs
+
+        config = self.default_config()
+        config["trusted_key_servers"] = [
+            {
+                "server_name": self.mock_perspective_server.server_name,
+                "verify_keys": self.mock_perspective_server.get_verify_keys(),
+            }
+        ]
+
+        return self.setup_test_homeserver(
+            handlers=None, http_client=self.http_client, config=config
+        )
 
     def check_context(self, _, expected):
         self.assertEquals(
@@ -80,7 +92,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # we run the lookup in a logcontext so that the patched inlineCallbacks can check
         # it is doing the right thing with logcontexts.
         wait_1_deferred = run_in_context(
-            kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+            kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
         )
 
         # there were no previous lookups, so the deferred should be ready
@@ -89,7 +101,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # set off another wait. It should block because the first lookup
         # hasn't yet completed.
         wait_2_deferred = run_in_context(
-            kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+            kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
         )
 
         self.assertFalse(wait_2_deferred.called)
@@ -132,7 +144,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 context_11.request = "11"
 
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1), ("server11", {})]
+                    [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -169,7 +181,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 self.http_client.post_json.return_value = defer.Deferred()
 
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1)]
+                    [("server10", json1, 0, "test")]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -192,31 +204,169 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.datastore.store_server_verify_key(
-            "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
+        r = self.hs.datastore.store_server_verify_keys(
+            "server9",
+            time.time() * 1000,
+            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
+        )
+        self.get_success(r)
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server9", key1)
+
+        # should fail immediately on an unsigned object
+        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        self.failureResultOf(d, SynapseError)
+
+        # should suceed on a signed object
+        d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+        # self.assertFalse(d.called)
+        self.get_success(d)
+
+    def test_verify_json_for_server_with_null_valid_until_ms(self):
+        """Tests that we correctly handle key requests for keys we've stored
+        with a null `ts_valid_until_ms`
+        """
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+
+        kr = keyring.Keyring(
+            self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+        )
+
+        key1 = signedjson.key.generate_signing_key(1)
+        r = self.hs.datastore.store_server_verify_keys(
+            "server9",
+            time.time() * 1000,
+            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
         )
         self.get_success(r)
+
         json1 = {}
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {})
+        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
         self.failureResultOf(d, SynapseError)
 
-        d = _verify_json_for_server(kr, "server9", json1)
-        self.assertFalse(d.called)
+        # should fail on a signed object with a non-zero minimum_valid_until_ms,
+        # as it tries to refetch the keys and fails.
+        d = _verify_json_for_server(
+            kr, "server9", json1, 500, "test signed non-zero min"
+        )
+        self.get_failure(d, SynapseError)
+
+        # We expect the keyring tried to refetch the key once.
+        mock_fetcher.get_keys.assert_called_once_with(
+            {"server9": {get_key_id(key1): 500}}
+        )
+
+        # should succeed on a signed object with a 0 minimum_valid_until_ms
+        d = _verify_json_for_server(
+            kr, "server9", json1, 0, "test signed with zero min"
+        )
         self.get_success(d)
 
+    def test_verify_json_dedupes_key_requests(self):
+        """Two requests for the same key should be deduped."""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys(keys_to_fetch):
+            # there should only be one request object (with the max validity)
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock(side_effect=get_keys)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        # the first request should succeed; the second should fail because the key
+        # has expired
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to the fetcher
+        mock_fetcher.get_keys.assert_called_once()
+
+    def test_verify_json_falls_back_to_other_fetchers(self):
+        """If the first fetcher cannot provide a recent enough key, we fall back"""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys1(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+                    }
+                }
+            )
+
+        def get_keys2(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to each fetcher
+        mock_fetcher1.get_keys.assert_called_once()
+        mock_fetcher2.get_keys.assert_called_once()
+
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.http_client = Mock()
+        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        return hs
+
     def test_get_keys_from_server(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
+        fetcher = ServerKeyFetcher(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
-        VALID_UNTIL_TS = 1000
+        VALID_UNTIL_TS = 200 * 1000
 
         # valid response
         response = {
@@ -238,12 +388,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -263,18 +414,37 @@ class KeyringTestCase(unittest.HomeserverTestCase):
             bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
         )
 
-        # change the server name: it should cause a rejection
+        # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
-        self.get_failure(
-            kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        self.assertEqual(keys, {})
+
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.mock_perspective_server = MockPerspectiveServer()
+        self.http_client = Mock()
+
+        config = self.default_config()
+        config["trusted_key_servers"] = [
+            {
+                "server_name": self.mock_perspective_server.server_name,
+                "verify_keys": self.mock_perspective_server.get_verify_keys(),
+            }
+        ]
+
+        return self.setup_test_homeserver(
+            handlers=None, http_client=self.http_client, config=config
         )
 
     def test_get_keys_from_perspectives(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
@@ -292,9 +462,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
             },
         }
 
-        persp_resp = {
-            "server_keys": [self.mock_perspective_server.get_signed_response(response)]
-        }
+        # the response must be signed by both the origin server and the perspectives
+        # server.
+        signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+        self.mock_perspective_server.sign_response(response)
 
         def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -303,17 +474,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
             # check that the request is for the expected key
             q = data["server_keys"]
             self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
-            return persp_resp
+            return {"server_keys": [response]}
 
         self.http_client.post_json.side_effect = post_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -329,26 +501,96 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
 
         self.assertEqual(
-            bytes(res["key_json"]),
-            canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
         )
 
+    def test_invalid_perspectives_responses(self):
+        """Check that invalid responses from the perspectives server are rejected"""
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        SERVER_NAME = "server2"
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        def build_response():
+            # valid response
+            response = {
+                "server_name": SERVER_NAME,
+                "old_verify_keys": {},
+                "valid_until_ts": VALID_UNTIL_TS,
+                "verify_keys": {
+                    testverifykey_id: {
+                        "key": signedjson.key.encode_verify_key_base64(testverifykey)
+                    }
+                },
+            }
+
+            # the response must be signed by both the origin server and the perspectives
+            # server.
+            signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+            self.mock_perspective_server.sign_response(response)
+            return response
+
+        def get_key_from_perspectives(response):
+            fetcher = PerspectivesKeyFetcher(self.hs)
+            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+
+            def post_json(destination, path, data, **kwargs):
+                self.assertEqual(destination, self.mock_perspective_server.server_name)
+                self.assertEqual(path, "/_matrix/key/v2/query")
+                return {"server_keys": [response]}
+
+            self.http_client.post_json.side_effect = post_json
+
+            return self.get_success(fetcher.get_keys(keys_to_fetch))
+
+        # start with a valid response so we can check we are testing the right thing
+        response = build_response()
+        keys = get_key_from_perspectives(response)
+        k = keys[SERVER_NAME][testverifykey_id]
+        self.assertEqual(k.verify_key, testverifykey)
+
+        # remove the perspectives server's signature
+        response = build_response()
+        del response["signatures"][self.mock_perspective_server.server_name]
+        self.http_client.post_json.return_value = {"server_keys": [response]}
+        keys = get_key_from_perspectives(response)
+        self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
+
+        # remove the origin server's signature
+        response = build_response()
+        del response["signatures"][SERVER_NAME]
+        self.http_client.post_json.return_value = {"server_keys": [response]}
+        keys = get_key_from_perspectives(response)
+        self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+
+
+def get_key_id(key):
+    """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+    return "%s:%s" % (key.alg, key.version)
+
 
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
+    with LoggingContext("testctx") as ctx:
+        # we set the "request" prop to make it easier to follow what's going on in the
+        # logs.
+        ctx.request = "testctx"
         rv = yield f(*args, **kwargs)
     defer.returnValue(rv)
 
 
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(kr, *args):
     """thin wrapper around verify_json_for_server which makes sure it is wrapped
     with the patched defer.inlineCallbacks.
     """
 
     @defer.inlineCallbacks
     def v():
-        rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+        rv1 = yield kr.verify_json_for_server(*args)
         defer.returnValue(rv1)
 
     return run_in_context(v)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index d0cc492deb..9e3d4d0f47 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -37,88 +37,88 @@ class PruneEventTestCase(unittest.TestCase):
 
     def test_minimal(self):
         self.run_test(
-            {'type': 'A', 'event_id': '$test:domain'},
+            {"type": "A", "event_id": "$test:domain"},
             {
-                'type': 'A',
-                'event_id': '$test:domain',
-                'content': {},
-                'signatures': {},
-                'unsigned': {},
+                "type": "A",
+                "event_id": "$test:domain",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
             },
         )
 
     def test_basic_keys(self):
         self.run_test(
             {
-                'type': 'A',
-                'room_id': '!1:domain',
-                'sender': '@2:domain',
-                'event_id': '$3:domain',
-                'origin': 'domain',
+                "type": "A",
+                "room_id": "!1:domain",
+                "sender": "@2:domain",
+                "event_id": "$3:domain",
+                "origin": "domain",
             },
             {
-                'type': 'A',
-                'room_id': '!1:domain',
-                'sender': '@2:domain',
-                'event_id': '$3:domain',
-                'origin': 'domain',
-                'content': {},
-                'signatures': {},
-                'unsigned': {},
+                "type": "A",
+                "room_id": "!1:domain",
+                "sender": "@2:domain",
+                "event_id": "$3:domain",
+                "origin": "domain",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
             },
         )
 
     def test_unsigned_age_ts(self):
         self.run_test(
-            {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}},
+            {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
             {
-                'type': 'B',
-                'event_id': '$test:domain',
-                'content': {},
-                'signatures': {},
-                'unsigned': {'age_ts': 20},
+                "type": "B",
+                "event_id": "$test:domain",
+                "content": {},
+                "signatures": {},
+                "unsigned": {"age_ts": 20},
             },
         )
 
         self.run_test(
             {
-                'type': 'B',
-                'event_id': '$test:domain',
-                'unsigned': {'other_key': 'here'},
+                "type": "B",
+                "event_id": "$test:domain",
+                "unsigned": {"other_key": "here"},
             },
             {
-                'type': 'B',
-                'event_id': '$test:domain',
-                'content': {},
-                'signatures': {},
-                'unsigned': {},
+                "type": "B",
+                "event_id": "$test:domain",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
             },
         )
 
     def test_content(self):
         self.run_test(
-            {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}},
+            {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
             {
-                'type': 'C',
-                'event_id': '$test:domain',
-                'content': {},
-                'signatures': {},
-                'unsigned': {},
+                "type": "C",
+                "event_id": "$test:domain",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
             },
         )
 
         self.run_test(
             {
-                'type': 'm.room.create',
-                'event_id': '$test:domain',
-                'content': {'creator': '@2:domain', 'other_field': 'here'},
+                "type": "m.room.create",
+                "event_id": "$test:domain",
+                "content": {"creator": "@2:domain", "other_field": "here"},
             },
             {
-                'type': 'm.room.create',
-                'event_id': '$test:domain',
-                'content': {'creator': '@2:domain'},
-                'signatures': {},
-                'unsigned': {},
+                "type": "m.room.create",
+                "event_id": "$test:domain",
+                "content": {"creator": "@2:domain"},
+                "signatures": {},
+                "unsigned": {},
             },
         )
 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
new file mode 100644
index 0000000000..a5b03005d7
--- /dev/null
+++ b/tests/federation/test_complexity.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+
+
+class RoomComplexityTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self, name="test"):
+        config = super(RoomComplexityTests, self).default_config(name=name)
+        config["limit_large_remote_room_joins"] = True
+        config["limit_large_remote_room_complexity"] = 0.05
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        class Authenticator(object):
+            def authenticate_request(self, request, content):
+                return defer.succeed("otherserver.nottld")
+
+        ratelimiter = FederationRateLimiter(
+            clock,
+            FederationRateLimitConfig(
+                window_size=1,
+                sleep_limit=1,
+                sleep_msec=1,
+                reject_limit=1000,
+                concurrent_requests=1000,
+            ),
+        )
+        server.register_servlets(
+            homeserver, self.resource, Authenticator(), ratelimiter
+        )
+
+    def test_complexity_simple(self):
+
+        u1 = self.register_user("u1", "pass")
+        u1_token = self.login("u1", "pass")
+
+        room_1 = self.helper.create_room_as(u1, tok=u1_token)
+        self.helper.send_state(
+            room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+        )
+
+        # Get the room complexity
+        request, channel = self.make_request(
+            "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        complexity = channel.json_body["v1"]
+        self.assertTrue(complexity > 0, complexity)
+
+        # Artificially raise the complexity
+        store = self.hs.get_datastore()
+        store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+
+        # Get the room complexity again -- make sure it's our artificial value
+        request, channel = self.make_request(
+            "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+        complexity = channel.json_body["v1"]
+        self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 7bb106b5f7..cce8d8c6de 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -51,16 +51,16 @@ class FederationSenderTestCases(HomeserverTestCase):
         json_cb = mock_send_transaction.call_args[0][1]
         data = json_cb()
         self.assertEqual(
-            data['edus'],
+            data["edus"],
             [
                 {
-                    'edu_type': 'm.receipt',
-                    'content': {
-                        'room_id': {
-                            'm.read': {
-                                'user_id': {
-                                    'event_ids': ['event_id'],
-                                    'data': {'ts': 1234},
+                    "edu_type": "m.receipt",
+                    "content": {
+                        "room_id": {
+                            "m.read": {
+                                "user_id": {
+                                    "event_ids": ["event_id"],
+                                    "data": {"ts": 1234},
                                 }
                             }
                         }
@@ -93,16 +93,16 @@ class FederationSenderTestCases(HomeserverTestCase):
         json_cb = mock_send_transaction.call_args[0][1]
         data = json_cb()
         self.assertEqual(
-            data['edus'],
+            data["edus"],
             [
                 {
-                    'edu_type': 'm.receipt',
-                    'content': {
-                        'room_id': {
-                            'm.read': {
-                                'user_id': {
-                                    'event_ids': ['event_id'],
-                                    'data': {'ts': 1234},
+                    "edu_type": "m.receipt",
+                    "content": {
+                        "room_id": {
+                            "m.read": {
+                                "user_id": {
+                                    "event_ids": ["event_id"],
+                                    "data": {"ts": 1234},
                                 }
                             }
                         }
@@ -128,16 +128,16 @@ class FederationSenderTestCases(HomeserverTestCase):
         json_cb = mock_send_transaction.call_args[0][1]
         data = json_cb()
         self.assertEqual(
-            data['edus'],
+            data["edus"],
             [
                 {
-                    'edu_type': 'm.receipt',
-                    'content': {
-                        'room_id': {
-                            'm.read': {
-                                'user_id': {
-                                    'event_ids': ['other_id'],
-                                    'data': {'ts': 1234},
+                    "edu_type": "m.receipt",
+                    "content": {
+                        "room_id": {
+                            "m.read": {
+                                "user_id": {
+                                    "event_ids": ["other_id"],
+                                    "data": {"ts": 1234},
                                 }
                             }
                         }
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 1e39fe0ec2..b204a0700d 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -117,7 +117,7 @@ class AuthTestCase(unittest.TestCase):
     def test_mau_limits_disabled(self):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield self.auth_handler.get_access_token_for_user_id('user_a')
+        yield self.auth_handler.get_access_token_for_user_id("user_a")
 
         yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self._get_macaroon().serialize()
@@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
         )
 
         with self.assertRaises(ResourceLimitError):
-            yield self.auth_handler.get_access_token_for_user_id('user_a')
+            yield self.auth_handler.get_access_token_for_user_id("user_a")
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.large_number_of_users)
@@ -150,7 +150,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
-            yield self.auth_handler.get_access_token_for_user_id('user_a')
+            yield self.auth_handler.get_access_token_for_user_id("user_a")
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
@@ -166,7 +166,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
-        yield self.auth_handler.get_access_token_for_user_id('user_a')
+        yield self.auth_handler.get_access_token_for_user_id("user_a")
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
@@ -185,7 +185,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.small_number_of_users)
         )
         # Ensure does not raise exception
-        yield self.auth_handler.get_access_token_for_user_id('user_a')
+        yield self.auth_handler.get_access_token_for_user_id("user_a")
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.small_number_of_users)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 917548bb31..91c7a17070 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -132,7 +132,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         request, channel = self.make_request(
             "PUT",
             b"directory/room/%23test%3Atest",
-            ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+            ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
         )
         self.render(request)
         self.assertEquals(403, channel.code, channel.result)
@@ -143,7 +143,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         request, channel = self.make_request(
             "PUT",
             b"directory/room/%23unofficial_test%3Atest",
-            ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+            ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
         )
         self.render(request)
         self.assertEquals(200, channel.code, channel.result)
@@ -158,7 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         room_id = self.helper.create_room_as(self.user_id)
 
         request, channel = self.make_request(
-            "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
+            "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.render(request)
         self.assertEquals(200, channel.code, channel.result)
@@ -190,7 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         # Room list disabled so we shouldn't be allowed to publish rooms
         room_id = self.helper.create_room_as(self.user_id)
         request, channel = self.make_request(
-            "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
+            "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.render(request)
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 2e72a1dd23..c4503c1611 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -395,37 +395,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         yield self.handler.upload_room_keys(self.local_user, version, room_keys)
 
         new_room_keys = copy.deepcopy(room_keys)
-        new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']
+        new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]
 
         # test that increasing the message_index doesn't replace the existing session
-        new_room_key['first_message_index'] = 2
-        new_room_key['session_data'] = 'new'
+        new_room_key["first_message_index"] = 2
+        new_room_key["session_data"] = "new"
         yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
 
         res = yield self.handler.get_room_keys(self.local_user, version)
         self.assertEqual(
-            res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+            res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # test that marking the session as verified however /does/ replace it
-        new_room_key['is_verified'] = True
+        new_room_key["is_verified"] = True
         yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
 
         res = yield self.handler.get_room_keys(self.local_user, version)
         self.assertEqual(
-            res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
+            res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # test that a session with a higher forwarded_count doesn't replace one
         # with a lower forwarding count
-        new_room_key['forwarded_count'] = 2
-        new_room_key['session_data'] = 'other'
+        new_room_key["forwarded_count"] = 2
+        new_room_key["session_data"] = "other"
         yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
 
         res = yield self.handler.get_room_keys(self.local_user, version)
         self.assertEqual(
-            res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
+            res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # TODO: check edge cases as well as the common variations here
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 5ffba2ca7a..4edce7af43 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -52,7 +52,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.mock_distributor.declare("registered_user")
         self.mock_captcha_client = Mock()
         self.macaroon_generator = Mock(
-            generate_access_token=Mock(return_value='secret')
+            generate_access_token=Mock(return_value="secret")
         )
         self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
         self.handler = self.hs.get_registration_handler()
@@ -71,7 +71,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(result_user_id, user_id)
         self.assertTrue(result_token is not None)
-        self.assertEquals(result_token, 'secret')
+        self.assertEquals(result_token, "secret")
 
     def test_if_user_exists(self):
         store = self.hs.get_datastore()
@@ -96,7 +96,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
         self.get_success(
-            self.handler.get_or_create_user(self.requester, 'a', "display_name")
+            self.handler.get_or_create_user(self.requester, "a", "display_name")
         )
 
     def test_get_or_create_user_mau_not_blocked(self):
@@ -105,7 +105,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
-        self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User"))
+        self.get_success(self.handler.get_or_create_user(self.requester, "c", "User"))
 
     def test_get_or_create_user_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
@@ -113,7 +113,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.lots_of_users)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+            self.handler.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -121,7 +121,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         self.get_failure(
-            self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+            self.handler.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
@@ -144,13 +144,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms(self):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
-        res = self.get_success(self.handler.register(localpart='jeff'))
+        res = self.get_success(self.handler.register(localpart="jeff"))
         rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
         directory_handler = self.hs.get_handlers().directory_handler
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
-        self.assertTrue(room_id['room_id'] in rooms)
+        self.assertTrue(room_id["room_id"] in rooms)
         self.assertEqual(len(rooms), 1)
 
     def test_auto_create_auto_join_rooms_with_no_rooms(self):
@@ -173,7 +173,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.hs.config.autocreate_auto_join_rooms = False
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
-        res = self.get_success(self.handler.register(localpart='jeff'))
+        res = self.get_success(self.handler.register(localpart="jeff"))
         rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
         self.assertEqual(len(rooms), 0)
 
@@ -182,7 +182,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.hs.config.auto_join_rooms = [room_alias_str]
 
         self.store.is_support_user = Mock(return_value=True)
-        res = self.get_success(self.handler.register(localpart='support'))
+        res = self.get_success(self.handler.register(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
         self.assertEqual(len(rooms), 0)
         directory_handler = self.hs.get_handlers().directory_handler
@@ -211,7 +211,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # When:-
         #   * the user is registered and post consent actions are called
-        res = self.get_success(self.handler.register(localpart='jeff'))
+        res = self.get_success(self.handler.register(localpart="jeff"))
         self.get_success(self.handler.post_consent_actions(res[0]))
 
         # Then:-
@@ -221,17 +221,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     def test_register_support_user(self):
         res = self.get_success(
-            self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
+            self.handler.register(localpart="user", user_type=UserTypes.SUPPORT)
         )
         self.assertTrue(self.store.is_support_user(res[0]))
 
     def test_register_not_support_user(self):
-        res = self.get_success(self.handler.register(localpart='user'))
+        res = self.get_success(self.handler.register(localpart="user"))
         self.assertFalse(self.store.is_support_user(res[0]))
 
     def test_invalid_user_id_length(self):
         invalid_user_id = "x" * 256
-        self.get_failure(
-            self.handler.register(localpart=invalid_user_id),
-            SynapseError
-        )
+        self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
new file mode 100644
index 0000000000..a8b858eb4f
--- /dev/null
+++ b/tests/handlers/test_stats.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class StatsRoomTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+
+        self.store = hs.get_datastore()
+        self.handler = self.hs.get_stats_handler()
+
+    def _add_background_updates(self):
+        """
+        Add the background updates we need to run.
+        """
+        # Ugh, have to reset this flag
+        self.store._all_done = False
+
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+            )
+        )
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_stats_process_rooms",
+                    "progress_json": "{}",
+                    "depends_on": "populate_stats_createtables",
+                },
+            )
+        )
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_stats_cleanup",
+                    "progress_json": "{}",
+                    "depends_on": "populate_stats_process_rooms",
+                },
+            )
+        )
+
+    def test_initial_room(self):
+        """
+        The background updates will build the table from scratch.
+        """
+        r = self.get_success(self.store.get_all_room_state())
+        self.assertEqual(len(r), 0)
+
+        # Disable stats
+        self.hs.config.stats_enabled = False
+        self.handler.stats_enabled = False
+
+        u1 = self.register_user("u1", "pass")
+        u1_token = self.login("u1", "pass")
+
+        room_1 = self.helper.create_room_as(u1, tok=u1_token)
+        self.helper.send_state(
+            room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+        )
+
+        # Stats disabled, shouldn't have done anything
+        r = self.get_success(self.store.get_all_room_state())
+        self.assertEqual(len(r), 0)
+
+        # Enable stats
+        self.hs.config.stats_enabled = True
+        self.handler.stats_enabled = True
+
+        # Do the initial population of the user directory via the background update
+        self._add_background_updates()
+
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+        r = self.get_success(self.store.get_all_room_state())
+
+        self.assertEqual(len(r), 1)
+        self.assertEqual(r[0]["topic"], "foo")
+
+    def test_initial_earliest_token(self):
+        """
+        Ingestion via notify_new_event will ignore tokens that the background
+        update have already processed.
+        """
+        self.reactor.advance(86401)
+
+        self.hs.config.stats_enabled = False
+        self.handler.stats_enabled = False
+
+        u1 = self.register_user("u1", "pass")
+        u1_token = self.login("u1", "pass")
+
+        u2 = self.register_user("u2", "pass")
+        u2_token = self.login("u2", "pass")
+
+        u3 = self.register_user("u3", "pass")
+        u3_token = self.login("u3", "pass")
+
+        room_1 = self.helper.create_room_as(u1, tok=u1_token)
+        self.helper.send_state(
+            room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+        )
+
+        # Begin the ingestion by creating the temp tables. This will also store
+        # the position that the deltas should begin at, once they take over.
+        self.hs.config.stats_enabled = True
+        self.handler.stats_enabled = True
+        self.store._all_done = False
+        self.get_success(self.store.update_stats_stream_pos(None))
+
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+            )
+        )
+
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+        # Now, before the table is actually ingested, add some more events.
+        self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room=room_1, user=u2, tok=u2_token)
+
+        # Now do the initial ingestion.
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+            )
+        )
+        self.get_success(
+            self.store._simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_stats_cleanup",
+                    "progress_json": "{}",
+                    "depends_on": "populate_stats_process_rooms",
+                },
+            )
+        )
+
+        self.store._all_done = False
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+        self.reactor.advance(86401)
+
+        # Now add some more events, triggering ingestion. Because of the stream
+        # position being set to before the events sent in the middle, a simpler
+        # implementation would reprocess those events, and say there were four
+        # users, not three.
+        self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
+        self.helper.join(room=room_1, user=u3, tok=u3_token)
+
+        # Get the deltas! There should be two -- day 1, and day 2.
+        r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+
+        # The oldest has 2 joined members
+        self.assertEqual(r[-1]["joined_members"], 2)
+
+        # The newest has 3
+        self.assertEqual(r[0]["joined_members"], 3)
+
+    def test_incorrect_state_transition(self):
+        """
+        If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
+        (JOIN, INVITE, LEAVE, BAN), an error is raised.
+        """
+        events = {
+            "a1": {"membership": Membership.LEAVE},
+            "a2": {"membership": "not a real thing"},
+        }
+
+        def get_event(event_id, allow_none=True):
+            m = Mock()
+            m.content = events[event_id]
+            d = defer.Deferred()
+            self.reactor.callLater(0.0, d.callback, m)
+            return d
+
+        def get_received_ts(event_id):
+            return defer.succeed(1)
+
+        self.store.get_received_ts = get_received_ts
+        self.store.get_event = get_event
+
+        deltas = [
+            {
+                "type": EventTypes.Member,
+                "state_key": "some_user",
+                "room_id": "room",
+                "event_id": "a1",
+                "prev_event_id": "a2",
+                "stream_id": 60,
+            }
+        ]
+
+        f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+        self.assertEqual(
+            f.value.args[0], "'not a real thing' is not a valid prev_membership"
+        )
+
+        # And the other way...
+        deltas = [
+            {
+                "type": EventTypes.Member,
+                "state_key": "some_user",
+                "room_id": "room",
+                "event_id": "a2",
+                "prev_event_id": "a1",
+                "stream_id": 100,
+            }
+        ]
+
+        f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+        self.assertEqual(
+            f.value.args[0], "'not a real thing' is not a valid membership"
+        )
+
+    def test_redacted_prev_event(self):
+        """
+        If the prev_event does not exist, then it is assumed to be a LEAVE.
+        """
+        u1 = self.register_user("u1", "pass")
+        u1_token = self.login("u1", "pass")
+
+        room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+        # Do the initial population of the user directory via the background update
+        self._add_background_updates()
+
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+        events = {"a1": None, "a2": {"membership": Membership.JOIN}}
+
+        def get_event(event_id, allow_none=True):
+            if events.get(event_id):
+                m = Mock()
+                m.content = events[event_id]
+            else:
+                m = None
+            d = defer.Deferred()
+            self.reactor.callLater(0.0, d.callback, m)
+            return d
+
+        def get_received_ts(event_id):
+            return defer.succeed(1)
+
+        self.store.get_received_ts = get_received_ts
+        self.store.get_event = get_event
+
+        deltas = [
+            {
+                "type": EventTypes.Member,
+                "state_key": "some_user:test",
+                "room_id": room_1,
+                "event_id": "a2",
+                "prev_event_id": "a1",
+                "stream_id": 100,
+            }
+        ]
+
+        # Handle our fake deltas, which has a user going from LEAVE -> JOIN.
+        self.get_success(self.handler._handle_deltas(deltas))
+
+        # One delta, with two joined members -- the room creator, and our fake
+        # user.
+        r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+        self.assertEqual(len(r), 1)
+        self.assertEqual(r[0]["joined_members"], 2)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index cb8b4d2913..5d5e324df2 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -47,7 +47,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
 
 
 def _make_edu_transaction_json(edu_type, content):
-    return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
+    return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
@@ -151,7 +151,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
@@ -209,12 +209,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                     "typing": True,
                 },
             ),
-            federation_auth_origin=b'farm',
+            federation_auth_origin=b"farm",
         )
         self.render(request)
         self.assertEqual(channel.code, 200)
 
-        self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
@@ -247,7 +247,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         put_json = self.hs.get_http_client().put_json
         put_json.assert_called_once_with(
@@ -285,7 +285,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -303,7 +303,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.reactor.pump([16])
 
-        self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
         events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
@@ -320,7 +320,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9021e647fe..b135486c48 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -60,15 +60,15 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
         profile = self.get_success(self.store.get_user_in_directory(support_user_id))
         self.assertTrue(profile is None)
-        display_name = 'display_name'
+        display_name = "display_name"
 
-        profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name)
-        regular_user_id = '@regular:test'
+        profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+        regular_user_id = "@regular:test"
         self.get_success(
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
-        self.assertTrue(profile['display_name'] == display_name)
+        self.assertTrue(profile["display_name"] == display_name)
 
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 851fc0eb33..2d5dba6464 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -13,28 +13,122 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os.path
+import subprocess
+
+from zope.interface import implementer
 
 from OpenSSL import SSL
+from OpenSSL.SSL import Connection
+from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+
+
+def get_test_ca_cert_file():
+    """Get the path to the test CA cert
+
+    The keypair is generated with:
+
+        openssl genrsa -out ca.key 2048
+        openssl req -new -x509 -key ca.key -days 3650 -out ca.crt \
+            -subj '/CN=synapse test CA'
+    """
+    return os.path.join(os.path.dirname(__file__), "ca.crt")
+
+
+def get_test_key_file():
+    """get the path to the test key
+
+    The key file is made with:
+
+        openssl genrsa -out server.key 2048
+    """
+    return os.path.join(os.path.dirname(__file__), "server.key")
+
+
+cert_file_count = 0
+
+CONFIG_TEMPLATE = b"""\
+[default]
+basicConstraints = CA:FALSE
+keyUsage=nonRepudiation, digitalSignature, keyEncipherment
+subjectAltName = %(sanentries)s
+"""
+
+
+def create_test_cert_file(sanlist):
+    """build an x509 certificate file
+
+    Args:
+        sanlist: list[bytes]: a list of subjectAltName values for the cert
+
+    Returns:
+        str: the path to the file
+    """
+    global cert_file_count
+    csr_filename = "server.csr"
+    cnf_filename = "server.%i.cnf" % (cert_file_count,)
+    cert_filename = "server.%i.crt" % (cert_file_count,)
+    cert_file_count += 1
+
+    # first build a CSR
+    subprocess.check_call(
+        [
+            "openssl",
+            "req",
+            "-new",
+            "-key",
+            get_test_key_file(),
+            "-subj",
+            "/",
+            "-out",
+            csr_filename,
+        ]
+    )
 
+    # now a config file describing the right SAN entries
+    sanentries = b",".join(sanlist)
+    with open(cnf_filename, "wb") as f:
+        f.write(CONFIG_TEMPLATE % {b"sanentries": sanentries})
 
-def get_test_cert_file():
-    """get the path to the test cert"""
+    # finally the cert
+    ca_key_filename = os.path.join(os.path.dirname(__file__), "ca.key")
+    ca_cert_filename = get_test_ca_cert_file()
+    subprocess.check_call(
+        [
+            "openssl",
+            "x509",
+            "-req",
+            "-in",
+            csr_filename,
+            "-CA",
+            ca_cert_filename,
+            "-CAkey",
+            ca_key_filename,
+            "-set_serial",
+            "1",
+            "-extfile",
+            cnf_filename,
+            "-out",
+            cert_filename,
+        ]
+    )
 
-    # the cert file itself is made with:
-    #
-    # openssl req -x509 -newkey rsa:4096 -keyout server.pem  -out server.pem -days 36500 \
-    #     -nodes -subj '/CN=testserv'
-    return os.path.join(os.path.dirname(__file__), 'server.pem')
+    return cert_filename
 
 
-class ServerTLSContext(object):
-    """A TLS Context which presents our test cert."""
+@implementer(IOpenSSLServerConnectionCreator)
+class TestServerTLSConnectionFactory(object):
+    """An SSL connection creator which returns connections which present a certificate
+    signed by our test CA."""
 
-    def __init__(self):
-        self.filename = get_test_cert_file()
+    def __init__(self, sanlist):
+        """
+        Args:
+            sanlist: list[bytes]: a list of subjectAltName values for the cert
+        """
+        self._cert_file = create_test_cert_file(sanlist)
 
-    def getContext(self):
+    def serverConnectionForTLS(self, tlsProtocol):
         ctx = SSL.Context(SSL.TLSv1_METHOD)
-        ctx.use_certificate_file(self.filename)
-        ctx.use_privatekey_file(self.filename)
-        return ctx
+        ctx.use_certificate_file(self._cert_file)
+        ctx.use_privatekey_file(get_test_key_file())
+        return Connection(ctx, None)
diff --git a/tests/http/ca.crt b/tests/http/ca.crt
new file mode 100644
index 0000000000..730f81e99c
--- /dev/null
+++ b/tests/http/ca.crt
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDCjCCAfKgAwIBAgIJAPwHIHgH/jtjMA0GCSqGSIb3DQEBCwUAMBoxGDAWBgNV
+BAMMD3N5bmFwc2UgdGVzdCBDQTAeFw0xOTA2MTAxMTI2NDdaFw0yOTA2MDcxMTI2
+NDdaMBoxGDAWBgNVBAMMD3N5bmFwc2UgdGVzdCBDQTCCASIwDQYJKoZIhvcNAQEB
+BQADggEPADCCAQoCggEBAOZOXCKuylf9jHzJXpU2nS+XEKrnGPgs2SAhQKrzBxg3
+/d8KT2Zsfsj1i3G7oGu7B0ZKO6qG5AxOPCmSMf9/aiSHFilfSh+r8rCpJyWMev2c
+/w/xmhoFHgn+H90NnqlXvWb5y1YZCE3gWaituQSaa93GPKacRqXCgIrzjPUuhfeT
+uwFQt4iyUhMNBYEy3aw4IuIHdyBqi4noUhR2ZeuflLJ6PswdJ8mEiAvxCbBGPerq
+idhWcZwlo0fKu4u1uu5B8TnTsMg2fJgL6c5olBG90Urt22gA6anfP5W/U1ZdVhmB
+T3Rv5SJMkGyMGE6sEUetLFyb2GJpgGD7ePkUCZr+IMMCAwEAAaNTMFEwHQYDVR0O
+BBYEFLg7nTCYsvQXWTyS6upLc0YTlIwRMB8GA1UdIwQYMBaAFLg7nTCYsvQXWTyS
+6upLc0YTlIwRMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADqx
+GX4Ul5OGQlcG+xTt4u3vMCeqGo8mh1AnJ7zQbyRmwjJiNxJVX+/EcqFSTsmkBNoe
+xdYITI7Z6dyoiKw99yCZDE7gALcyACEU7r0XY7VY/hebAaX6uLaw1sZKKAIC04lD
+KgCu82tG85n60Qyud5SiZZF0q1XVq7lbvOYVdzVZ7k8Vssy5p9XnaLJLMggYeOiX
+psHIQjvYGnTTEBZZHzWOrc0WGThd69wxTOOkAbCsoTPEwZL8BGUsdtLWtvhp452O
+npvaUBzKg39R5X3KTdhB68XptiQfzbQkd3FtrwNuYPUywlsg55Bxkv85n57+xDO3
+D9YkgUqEp0RGUXQgCsQ=
+-----END CERTIFICATE-----
diff --git a/tests/http/ca.key b/tests/http/ca.key
new file mode 100644
index 0000000000..5c99cae186
--- /dev/null
+++ b/tests/http/ca.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpgIBAAKCAQEA5k5cIq7KV/2MfMlelTadL5cQqucY+CzZICFAqvMHGDf93wpP
+Zmx+yPWLcbuga7sHRko7qobkDE48KZIx/39qJIcWKV9KH6vysKknJYx6/Zz/D/Ga
+GgUeCf4f3Q2eqVe9ZvnLVhkITeBZqK25BJpr3cY8ppxGpcKAivOM9S6F95O7AVC3
+iLJSEw0FgTLdrDgi4gd3IGqLiehSFHZl65+Usno+zB0nyYSIC/EJsEY96uqJ2FZx
+nCWjR8q7i7W67kHxOdOwyDZ8mAvpzmiUEb3RSu3baADpqd8/lb9TVl1WGYFPdG/l
+IkyQbIwYTqwRR60sXJvYYmmAYPt4+RQJmv4gwwIDAQABAoIBAQCFuFG+wYYy+MCt
+Y65LLN6vVyMSWAQjdMbM5QHLQDiKU1hQPIhFjBFBVXCVpL9MTde3dDqYlKGsk3BT
+ItNs6eoTM2wmsXE0Wn4bHNvh7WMsBhACjeFP4lDCtI6DpvjMkmkidT8eyoIL1Yu5
+aMTYa2Dd79AfXPWYIQrJowfhBBY83KuW5fmYnKKDVLqkT9nf2dgmmQz85RgtNiZC
+zFkIsNmPqH1zRbcw0wORfOBrLFvsMc4Tt8EY5Wz3NnH8Zfgf8Q3MgARH1yspz3Vp
+B+EYHbsK17xZ+P59KPiX3yefvyYWEUjFF7ymVsVnDxLugYl4pXwWUpm19GxeDvFk
+cgBUD5OBAoGBAP7lBdCp6lx6fYtxdxUm3n4MMQmYcac4qZdeBIrvpFMnvOBBuixl
+eavcfFmFdwgAr8HyVYiu9ynac504IYvmtYlcpUmiRBbmMHbvLQEYHl7FYFKNz9ej
+2ue4oJE3RsPdLsD3xIlc+xN8oT1j0knyorwsHdj0Sv77eZzZS9XZZfJzAoGBAOdO
+CibYmoNqK/mqDHkp6PgsnbQGD5/CvPF/BLUWV1QpHxLzUQQeoBOQW5FatHe1H5zi
+mbq3emBefVmsCLrRIJ4GQu4vsTMfjcpGLwviWmaK6pHbGPt8IYeEQ2MNyv59EtA2
+pQy4dX7/Oe6NLAR1UEQjXmCuXf+rxnxF3VJd1nRxAoGBANb9eusl9fusgSnVOTjJ
+AQ7V36KVRv9hZoG6liBNwo80zDVmms4JhRd1MBkd3mkMkzIF4SkZUnWlwLBSANGM
+dX/3eZ5i1AVwgF5Am/f5TNxopDbdT/o1RVT/P8dcFT7s1xuBn+6wU0F7dFBgWqVu
+lt4aY85zNrJcj5XBHhqwdDGLAoGBAIksPNUAy9F3m5C6ih8o/aKAQx5KIeXrBUZq
+v43tK+kbYfRJHBjHWMOBbuxq0G/VmGPf9q9GtGqGXuxZG+w+rYtJx1OeMQZShjIZ
+ITl5CYeahrXtK4mo+fF2PMh3m5UE861LWuKKWhPwpJiWXC5grDNcjlHj1pcTdeip
+PjHkuJPhAoGBAIh35DptqqdicOd3dr/+/m2YQywY8aSpMrR0bC06aAkscD7oq4tt
+s/jwl0UlHIrEm/aMN7OnGIbpfkVdExfGKYaa5NRlgOwQpShwLufIo/c8fErd2zb8
+K3ptlwBxMrayMXpS3DP78r83Z0B8/FSK2guelzdRJ3ftipZ9io1Gss1C
+-----END RSA PRIVATE KEY-----
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index ed0ca079d9..417fda3ab2 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,16 +17,19 @@ import logging
 from mock import Mock
 
 import treq
+from service_identity import VerificationError
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web._newclient import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IPolicyForHTTPS
 
+from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
 from synapse.http.federation.matrix_federation_agent import (
     MatrixFederationAgent,
@@ -36,13 +39,31 @@ from synapse.http.federation.srv_resolver import Server
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.logcontext import LoggingContext
 
-from tests.http import ServerTLSContext
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
 from tests.unittest import TestCase
 from tests.utils import default_config
 
 logger = logging.getLogger(__name__)
 
+test_server_connection_factory = None
+
+
+def get_connection_factory():
+    # this needs to happen once, but not until we are ready to run the first test
+    global test_server_connection_factory
+    if test_server_connection_factory is None:
+        test_server_connection_factory = TestServerTLSConnectionFactory(
+            sanlist=[
+                b"DNS:testserv",
+                b"DNS:target-server",
+                b"DNS:xn--bcher-kva.com",
+                b"IP:1.2.3.4",
+                b"IP:::1",
+            ]
+        )
+    return test_server_connection_factory
+
 
 class MatrixFederationAgentTests(TestCase):
     def setUp(self):
@@ -52,11 +73,16 @@ class MatrixFederationAgentTests(TestCase):
 
         self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
 
+        config_dict = default_config("test", parse=False)
+        config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+        # config_dict["trusted_key_servers"] = []
+
+        self._config = config = HomeServerConfig()
+        config.parse_config_dict(config_dict, "", "")
+
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
-            tls_client_options_factory=ClientTLSOptionsFactory(
-                default_config("test", parse=True)
-            ),
+            tls_client_options_factory=ClientTLSOptionsFactory(config),
             _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
             _srv_resolver=self.mock_resolver,
             _well_known_cache=self.well_known_cache,
@@ -70,7 +96,7 @@ class MatrixFederationAgentTests(TestCase):
         """
 
         # build the test server
-        server_tls_protocol = _build_test_server()
+        server_tls_protocol = _build_test_server(get_connection_factory())
 
         # now, tell the client protocol factory to build the client protocol (it will be a
         # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -109,7 +135,7 @@ class MatrixFederationAgentTests(TestCase):
         Sends a simple GET request via the agent, and checks its logcontext management
         """
         with LoggingContext("one") as context:
-            fetch_d = self.agent.request(b'GET', uri)
+            fetch_d = self.agent.request(b"GET", uri)
 
             # Nothing happened yet
             self.assertNoResult(fetch_d)
@@ -153,9 +179,9 @@ class MatrixFederationAgentTests(TestCase):
         """Check that an incoming request looks like a valid .well-known request, and
         send back the response.
         """
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/.well-known/matrix/server')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/.well-known/matrix/server")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
         # send back a response
         for k, v in headers.items():
             request.setHeader(k, v)
@@ -178,7 +204,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
@@ -186,20 +212,20 @@ class MatrixFederationAgentTests(TestCase):
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448']
+            request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
         )
         content = request.content.read()
-        self.assertEqual(content, b'')
+        self.assertEqual(content, b"")
 
         # Deferred is still without a result
         self.assertNoResult(test_d)
 
         # send the headers
-        request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
-        request.write('')
+        request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"])
+        request.write("")
 
         self.reactor.pump((0.1,))
 
@@ -209,7 +235,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(response.code, 200)
 
         # Send the body
-        request.write('{ "a": 1 }'.encode('ascii'))
+        request.write('{ "a": 1 }'.encode("ascii"))
         request.finish()
 
         self.reactor.pump((0.1,))
@@ -234,7 +260,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
@@ -242,9 +268,9 @@ class MatrixFederationAgentTests(TestCase):
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"])
 
         # finish the request
         request.finish()
@@ -269,7 +295,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '::1')
+        self.assertEqual(host, "::1")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
@@ -277,9 +303,9 @@ class MatrixFederationAgentTests(TestCase):
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"])
 
         # finish the request
         request.finish()
@@ -304,7 +330,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '::1')
+        self.assertEqual(host, "::1")
         self.assertEqual(port, 80)
 
         # make a test server, and wire up the client
@@ -312,15 +338,97 @@ class MatrixFederationAgentTests(TestCase):
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"])
 
         # finish the request
         request.finish()
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
+    def test_get_hostname_bad_cert(self):
+        """
+        Test the behaviour when the certificate on the server doesn't match the hostname
+        """
+        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.reactor.lookups["testserv1"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # No SRV record lookup yet
+        self.mock_resolver.resolve_service.assert_not_called()
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        # fonx the connection
+        client_factory.clientConnectionFailed(None, Exception("nope"))
+
+        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
+        # .well-known request fails.
+        self.reactor.pump((0.4,))
+
+        # now there should be a SRV lookup
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.testserv1"
+        )
+
+        # we should fall back to a direct connection
+        self.assertEqual(len(clients), 2)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8448)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv1")
+
+        # there should be no requests
+        self.assertEqual(len(http_server.requests), 0)
+
+        # ... and the request should have failed
+        e = self.failureResultOf(test_d, ResponseNeverReceived)
+        failure_reason = e.value.reasons[0]
+        self.assertIsInstance(failure_reason.value, VerificationError)
+
+    def test_get_ip_address_bad_cert(self):
+        """
+        Test the behaviour when the server name contains an explicit IP, but
+        the server cert doesn't cover it
+        """
+        # there will be a getaddrinfo on the IP
+        self.reactor.lookups["1.2.3.5"] = "1.2.3.5"
+
+        test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8448)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(client_factory, expected_sni=None)
+
+        # there should be no requests
+        self.assertEqual(len(http_server.requests), 0)
+
+        # ... and the request should have failed
+        e = self.failureResultOf(test_d, ResponseNeverReceived)
+        failure_reason = e.value.reasons[0]
+        self.assertIsInstance(failure_reason.value, VerificationError)
+
     def test_get_no_srv_no_well_known(self):
         """
         Test the behaviour when the server name has no port, no SRV, and no well-known
@@ -341,7 +449,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         # fonx the connection
@@ -359,17 +467,17 @@ class MatrixFederationAgentTests(TestCase):
         # we should fall back to a direct connection
         self.assertEqual(len(clients), 2)
         (host, port, client_factory, _timeout, _bindAddress) = clients[1]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
-        http_server = self._make_connection(client_factory, expected_sni=b'testserv')
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
 
         # finish the request
         request.finish()
@@ -393,7 +501,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         self._handle_well_known_connection(
@@ -410,20 +518,20 @@ class MatrixFederationAgentTests(TestCase):
         # now we should get a connection to the target server
         self.assertEqual(len(clients), 2)
         (host, port, client_factory, _timeout, _bindAddress) = clients[1]
-        self.assertEqual(host, '1::f')
+        self.assertEqual(host, "1::f")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, expected_sni=b'target-server'
+            client_factory, expected_sni=b"target-server"
         )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
+            request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
         )
 
         # finish the request
@@ -455,7 +563,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         redirect_server = self._make_connection(
@@ -465,7 +573,7 @@ class MatrixFederationAgentTests(TestCase):
         # send a 302 redirect
         self.assertEqual(len(redirect_server.requests), 1)
         request = redirect_server.requests[0]
-        request.redirect(b'https://testserv/even_better_known')
+        request.redirect(b"https://testserv/even_better_known")
         request.finish()
 
         self.reactor.pump((0.1,))
@@ -474,7 +582,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         well_known_server = self._make_connection(
@@ -483,8 +591,8 @@ class MatrixFederationAgentTests(TestCase):
 
         self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
         request = well_known_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/even_better_known')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/even_better_known")
         request.write(b'{ "m.server": "target-server" }')
         request.finish()
 
@@ -498,20 +606,20 @@ class MatrixFederationAgentTests(TestCase):
         # now we should get a connection to the target server
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1::f')
+        self.assertEqual(host, "1::f")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, expected_sni=b'target-server'
+            client_factory, expected_sni=b"target-server"
         )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
+            request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
         )
 
         # finish the request
@@ -546,11 +654,11 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         self._handle_well_known_connection(
-            client_factory, expected_sni=b"testserv", content=b'NOT JSON'
+            client_factory, expected_sni=b"testserv", content=b"NOT JSON"
         )
 
         # now there should be a SRV lookup
@@ -561,23 +669,64 @@ class MatrixFederationAgentTests(TestCase):
         # we should fall back to a direct connection
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
-        http_server = self._make_connection(client_factory, expected_sni=b'testserv')
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
 
         # finish the request
         request.finish()
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
+    def test_get_well_known_unsigned_cert(self):
+        """Test the behaviour when the .well-known server presents a cert
+        not signed by a CA
+        """
+
+        # we use the same test server as the other tests, but use an agent
+        # with _well_known_tls_policy left to the default, which will not
+        # trust it (since the presented cert is signed by a test CA)
+
+        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        agent = MatrixFederationAgent(
+            reactor=self.reactor,
+            tls_client_options_factory=ClientTLSOptionsFactory(self._config),
+            _srv_resolver=self.mock_resolver,
+            _well_known_cache=self.well_known_cache,
+        )
+
+        test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        http_proto = self._make_connection(client_factory, expected_sni=b"testserv")
+
+        # there should be no requests
+        self.assertEqual(len(http_proto.requests), 0)
+
+        # and there should be a SRV lookup instead
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.testserv"
+        )
+
     def test_get_hostname_srv(self):
         """
         Test the behaviour when there is a single SRV record
@@ -601,17 +750,17 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8443)
 
         # make a test server, and wire up the client
-        http_server = self._make_connection(client_factory, expected_sni=b'testserv')
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
-        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
 
         # finish the request
         request.finish()
@@ -634,7 +783,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         self.mock_resolver.resolve_service.side_effect = lambda _: [
@@ -655,20 +804,20 @@ class MatrixFederationAgentTests(TestCase):
         # now we should get a connection to the target of the SRV record
         self.assertEqual(len(clients), 2)
         (host, port, client_factory, _timeout, _bindAddress) = clients[1]
-        self.assertEqual(host, '5.6.7.8')
+        self.assertEqual(host, "5.6.7.8")
         self.assertEqual(port, 8443)
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, expected_sni=b'target-server'
+            client_factory, expected_sni=b"target-server"
         )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
+            request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
         )
 
         # finish the request
@@ -697,7 +846,7 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         # fonx the connection
@@ -716,20 +865,20 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 2)
         (host, port, client_factory, _timeout, _bindAddress) = clients[1]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, expected_sni=b'xn--bcher-kva.com'
+            client_factory, expected_sni=b"xn--bcher-kva.com"
         )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
+            request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]
         )
 
         # finish the request
@@ -758,20 +907,20 @@ class MatrixFederationAgentTests(TestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8443)
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, expected_sni=b'xn--bcher-kva.com'
+            client_factory, expected_sni=b"xn--bcher-kva.com"
         )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
-        self.assertEqual(request.method, b'GET')
-        self.assertEqual(request.path, b'/foo/bar')
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
         self.assertEqual(
-            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
+            request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]
         )
 
         # finish the request
@@ -792,42 +941,42 @@ class MatrixFederationAgentTests(TestCase):
     def test_well_known_cache(self):
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        fetch_d = self.do_get_well_known(b'testserv')
+        fetch_d = self.do_get_well_known(b"testserv")
 
         # there should be an attempt to connect on port 443 for the .well-known
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         well_known_server = self._handle_well_known_connection(
             client_factory,
             expected_sni=b"testserv",
-            response_headers={b'Cache-Control': b'max-age=10'},
+            response_headers={b"Cache-Control": b"max-age=10"},
             content=b'{ "m.server": "target-server" }',
         )
 
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b'target-server')
+        self.assertEqual(r, b"target-server")
 
         # close the tcp connection
         well_known_server.loseConnection()
 
         # repeat the request: it should hit the cache
-        fetch_d = self.do_get_well_known(b'testserv')
+        fetch_d = self.do_get_well_known(b"testserv")
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b'target-server')
+        self.assertEqual(r, b"target-server")
 
         # expire the cache
         self.reactor.pump((10.0,))
 
         # now it should connect again
-        fetch_d = self.do_get_well_known(b'testserv')
+        fetch_d = self.do_get_well_known(b"testserv")
 
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
         self._handle_well_known_connection(
@@ -837,7 +986,7 @@ class MatrixFederationAgentTests(TestCase):
         )
 
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b'other-server')
+        self.assertEqual(r, b"other-server")
 
 
 class TestCachePeriodFromHeaders(TestCase):
@@ -845,27 +994,27 @@ class TestCachePeriodFromHeaders(TestCase):
         # uppercase
         self.assertEqual(
             _cache_period_from_headers(
-                Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']})
+                Headers({b"Cache-Control": [b"foo, Max-Age = 100, bar"]})
             ),
             100,
         )
 
         # missing value
         self.assertIsNone(
-            _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']}))
+            _cache_period_from_headers(Headers({b"Cache-Control": [b"max-age=, bar"]}))
         )
 
         # hackernews: bogus due to semicolon
         self.assertIsNone(
             _cache_period_from_headers(
-                Headers({b'Cache-Control': [b'private; max-age=0']})
+                Headers({b"Cache-Control": [b"private; max-age=0"]})
             )
         )
 
         # github
         self.assertEqual(
             _cache_period_from_headers(
-                Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']})
+                Headers({b"Cache-Control": [b"max-age=0, private, must-revalidate"]})
             ),
             0,
         )
@@ -873,7 +1022,7 @@ class TestCachePeriodFromHeaders(TestCase):
         # google
         self.assertEqual(
             _cache_period_from_headers(
-                Headers({b'cache-control': [b'private, max-age=0']})
+                Headers({b"cache-control": [b"private, max-age=0"]})
             ),
             0,
         )
@@ -881,7 +1030,7 @@ class TestCachePeriodFromHeaders(TestCase):
     def test_expires(self):
         self.assertEqual(
             _cache_period_from_headers(
-                Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+                Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
                 time_now=lambda: 1548833700,
             ),
             33,
@@ -892,8 +1041,8 @@ class TestCachePeriodFromHeaders(TestCase):
             _cache_period_from_headers(
                 Headers(
                     {
-                        b'cache-control': [b'max-age=10'],
-                        b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'],
+                        b"cache-control": [b"max-age=10"],
+                        b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"],
                     }
                 ),
                 time_now=lambda: 1548833700,
@@ -902,7 +1051,7 @@ class TestCachePeriodFromHeaders(TestCase):
         )
 
         # invalid expires means immediate expiry
-        self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0)
+        self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
 
 
 def _check_logcontext(context):
@@ -911,11 +1060,17 @@ def _check_logcontext(context):
         raise AssertionError("Expected logcontext %s but was %s" % (context, current))
 
 
-def _build_test_server():
+def _build_test_server(connection_creator):
     """Construct a test server
 
     This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
 
+    Args:
+        connection_creator (IOpenSSLServerConnectionCreator): thing to build
+            SSL connections
+        sanlist (list[bytes]): list of the SAN entries for the cert returned
+            by the server
+
     Returns:
         TLSMemoryBIOProtocol
     """
@@ -924,7 +1079,7 @@ def _build_test_server():
     server_factory.log = _log_request
 
     server_tls_factory = TLSMemoryBIOFactory(
-        ServerTLSContext(), isClient=False, wrappedFactory=server_factory
+        connection_creator, isClient=False, wrappedFactory=server_factory
     )
 
     return server_tls_factory.buildProtocol(None)
@@ -937,7 +1092,8 @@ def _log_request(request):
 
 @implementer(IPolicyForHTTPS)
 class TrustingTLSPolicyForHTTPS(object):
-    """An IPolicyForHTTPS which doesn't do any certificate verification"""
+    """An IPolicyForHTTPS which checks that the certificate belongs to the
+    right server, but doesn't check the certificate chain."""
 
     def creatorForNetloc(self, hostname, port):
         certificateOptions = OpenSSLCertificateOptions()
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 034c0db8d2..cf6c6e95b5 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -100,7 +100,7 @@ class SrvResolverTestCase(unittest.TestCase):
     def test_from_cache(self):
         clock = MockClock()
 
-        dns_client_mock = Mock(spec_set=['lookupService'])
+        dns_client_mock = Mock(spec_set=["lookupService"])
         dns_client_mock.lookupService = Mock(spec_set=[])
 
         service_name = b"test_service.example.com"
diff --git a/tests/http/server.key b/tests/http/server.key
new file mode 100644
index 0000000000..c53ee02b21
--- /dev/null
+++ b/tests/http/server.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAvUAWLOE6TEp3FYSfEnJMwYtJg3KIW5BjiAOOvFVOVQfJ5eEa
+vzyJ1Z+8DUgLznFnUkAeD9GjPvP7awl3NPJKLQSMkV5Tp+ea4YyV+Aa4R7flROEa
+zCGvmleydZw0VqN1atVZ0ikEoglM/APJQd70ec7KSR3QoxaV2/VNCHmyAPdP+0WI
+llV54VXX1CZrWSHaCSn1gzo3WjnGbxTOCQE5Z4k5hqJAwLWWhxDv+FX/jD38Sq3H
+gMFNpXJv6FYwwaKU8awghHdSY/qlBPE/1rU83vIBFJ3jW6I1WnQDfCQ69of5vshK
+N4v4hok56ScwdUnk8lw6xvJx1Uav/XQB9qGh4QIDAQABAoIBAQCHLO5p8hotAgdb
+JFZm26N9nxrMPBOvq0ucjEX4ucnwrFaGzynGrNwa7TRqHCrqs0/EjS2ryOacgbL0
+eldeRy26SASLlN+WD7UuI7e+6DXabDzj3RHB+tGuIbPDk+ZCeBDXVTsKBOhdQN1v
+KNkpJrJjCtSsMxKiWvCBow353srJKqCDZcF5NIBYBeDBPMoMbfYn5dJ9JhEf+2h4
+0iwpnWDX1Vqf46pCRa0hwEyMXycGeV2CnfJSyV7z52ZHQrvkz8QspSnPpnlCnbOE
+UAvc8kZ5e8oZE7W+JfkK38vHbEGM1FCrBmrC/46uUGMRpZfDferGs91RwQVq/F0n
+JN9hLzsBAoGBAPh2pm9Xt7a4fWSkX0cDgjI7PT2BvLUjbRwKLV+459uDa7+qRoGE
+sSwb2QBqmQ1kbr9JyTS+Ld8dyUTsGHZK+YbTieAxI3FBdKsuFtcYJO/REN0vik+6
+fMaBHPvDHSU2ioq7spZ4JBFskzqs38FvZ0lX7aa3fguMk8GMLnofQ8QxAoGBAML9
+o5sJLN9Tk9bv2aFgnERgfRfNjjV4Wd99TsktnCD04D1GrP2eDSLfpwFlCnguck6b
+jxikqcolsNhZH4dgYHqRNj+IljSdl+sYZiygO6Ld0XU+dEFO86N3E9NzZhKcQ1at
+85VdwNPCS7JM2fIxEvS9xfbVnsmK6/37ZZ5iI7yxAoGBALw2vRtJGmy60pojfd1A
+hibhAyINnlKlFGkSOI7zdgeuRTf6l9BTIRclvTt4hJpFgzM6hMWEbyE94hJoupsZ
+bm443o/LCWsox2VI05p6urhD6f9znNWKkiyY78izY+elqksvpjgfqEresaTYAeP5
+LQe9KNSK2VuMUP1j4G04M9BxAoGAWe8ITZJuytZOgrz/YIohqPvj1l2tcIYA1a6C
+7xEFSMIIxtpZIWSLZIFJEsCakpHBkPX4iwIveZfmt/JrM1JFTWK6ZZVGyh/BmOIZ
+Bg4lU1oBqJTUo+aZQtTCJS29b2n5OPpkNYkXTdP4e9UsVKNDvfPlYZJneUeEzxDr
+bqCPIRECgYA544KMwrWxDQZg1dsKWgdVVKx80wEFZAiQr9+0KF6ch6Iu7lwGJHFY
+iI6O85paX41qeC/Fo+feIWJVJU2GvG6eBsbO4bmq+KSg4NkABJSYxodgBp9ftNeD
+jo1tfw+gudlNe5jXHu7oSX93tqGjR4Cnlgan/KtfkB96yHOumGmOhQ==
+-----END RSA PRIVATE KEY-----
diff --git a/tests/http/server.pem b/tests/http/server.pem
deleted file mode 100644
index 0584cf1a80..0000000000
--- a/tests/http/server.pem
+++ /dev/null
@@ -1,81 +0,0 @@
------BEGIN PRIVATE KEY-----
-MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
-x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
-r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
-ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
-Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
-9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
-0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
-8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
-qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
-QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
-oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
-dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
-aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
-5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
-gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
-jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
-QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
-CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
-yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
-j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
-6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
-AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
-uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
-qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
-W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
-DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
-tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
-XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
-REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
-remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
-nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
-B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
-kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
-NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
-nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
-k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
-pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
-L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
-6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
-yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
-gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
-I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
-qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
-/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
-YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
-VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
-QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
-/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
-FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
-KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
------END PRIVATE KEY-----
------BEGIN CERTIFICATE-----
-MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
-BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
-MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
-ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
-T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
-RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
-sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
-mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
-VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
-ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
-D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
-4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
-cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
-uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
-A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
-5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
-AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
-7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
-FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
-fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
-ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
-a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
-V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
-H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
-TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
-JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
-f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
------END CERTIFICATE-----
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index 3b0155ed03..b2e9533b07 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -20,12 +20,12 @@ from tests import unittest
 class ServerNameTestCase(unittest.TestCase):
     def test_parse_server_name(self):
         test_data = {
-            'localhost': ('localhost', None),
-            'my-example.com:1234': ('my-example.com', 1234),
-            '1.2.3.4': ('1.2.3.4', None),
-            '[0abc:1def::1234]': ('[0abc:1def::1234]', None),
-            '1.2.3.4:1': ('1.2.3.4', 1),
-            '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
+            "localhost": ("localhost", None),
+            "my-example.com:1234": ("my-example.com", 1234),
+            "1.2.3.4": ("1.2.3.4", None),
+            "[0abc:1def::1234]": ("[0abc:1def::1234]", None),
+            "1.2.3.4:1": ("1.2.3.4", 1),
+            "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080),
         }
 
         for i, o in test_data.items():
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ee767f3a5a..c4c0d9b968 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -83,7 +83,7 @@ class FederationClientTests(HomeserverTestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8008)
 
         # complete the connection and wire it up to a fake transport
@@ -99,7 +99,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertNoResult(test_d)
 
         # Send it the HTTP response
-        res_json = '{ "a": 1 }'.encode('ascii')
+        res_json = '{ "a": 1 }'.encode("ascii")
         protocol.dataReceived(
             b"HTTP/1.1 200 OK\r\n"
             b"Server: Fake\r\n"
@@ -138,7 +138,7 @@ class FederationClientTests(HomeserverTestCase):
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 8008)
         e = Exception("go away")
         factory.clientConnectionFailed(None, e)
@@ -164,7 +164,7 @@ class FederationClientTests(HomeserverTestCase):
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
-        self.assertEqual(clients[0][0], '1.2.3.4')
+        self.assertEqual(clients[0][0], "1.2.3.4")
         self.assertEqual(clients[0][1], 8008)
 
         # Deferred is still without a result
@@ -194,7 +194,7 @@ class FederationClientTests(HomeserverTestCase):
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
-        self.assertEqual(clients[0][0], '1.2.3.4')
+        self.assertEqual(clients[0][0], "1.2.3.4")
         self.assertEqual(clients[0][1], 8008)
 
         conn = Mock()
@@ -215,10 +215,9 @@ class FederationClientTests(HomeserverTestCase):
         """Ensure that Synapse does not try to connect to blacklisted IPs"""
 
         # Set up the ip_range blacklist
-        self.hs.config.federation_ip_range_blacklist = IPSet([
-            "127.0.0.0/8",
-            "fe80::/64",
-        ])
+        self.hs.config.federation_ip_range_blacklist = IPSet(
+            ["127.0.0.0/8", "fe80::/64"]
+        )
         self.reactor.lookups["internal"] = "127.0.0.1"
         self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
         self.reactor.lookups["fine"] = "10.20.30.40"
@@ -382,7 +381,7 @@ class FederationClientTests(HomeserverTestCase):
             b"Content-Type: application/json\r\n"
             b"Content-Length: 2\r\n"
             b"\r\n"
-            b'{}'
+            b"{}"
         )
 
         # We should get a successful response
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 9cdde1a9bd..358b593cd4 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -15,6 +15,7 @@
 
 import os
 
+import attr
 import pkg_resources
 
 from twisted.internet.defer import Deferred
@@ -24,15 +25,16 @@ from synapse.rest.client.v1 import login, room
 
 from tests.unittest import HomeserverTestCase
 
-try:
-    from synapse.push.mailer import load_jinja2_templates
-except Exception:
-    load_jinja2_templates = None
+
+@attr.s
+class _User(object):
+    "Helper wrapper for user ID and access token"
+    id = attr.ib()
+    token = attr.ib()
 
 
 class EmailPusherTests(HomeserverTestCase):
 
-    skip = "No Jinja installed" if not load_jinja2_templates else None
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -55,7 +57,7 @@ class EmailPusherTests(HomeserverTestCase):
         config["email"] = {
             "enable_notifs": True,
             "template_dir": os.path.abspath(
-                pkg_resources.resource_filename('synapse', 'res/templates')
+                pkg_resources.resource_filename("synapse", "res/templates")
             ),
             "expiry_template_html": "notice_expiry.html",
             "expiry_template_text": "notice_expiry.txt",
@@ -77,25 +79,32 @@ class EmailPusherTests(HomeserverTestCase):
 
         return hs
 
-    def test_sends_email(self):
-
+    def prepare(self, reactor, clock, hs):
         # Register the user who gets notified
-        user_id = self.register_user("user", "pass")
-        access_token = self.login("user", "pass")
-
-        # Register the user who sends the message
-        other_user_id = self.register_user("otheruser", "pass")
-        other_access_token = self.login("otheruser", "pass")
+        self.user_id = self.register_user("user", "pass")
+        self.access_token = self.login("user", "pass")
+
+        # Register other users
+        self.others = [
+            _User(
+                id=self.register_user("otheruser1", "pass"),
+                token=self.login("otheruser1", "pass"),
+            ),
+            _User(
+                id=self.register_user("otheruser2", "pass"),
+                token=self.login("otheruser2", "pass"),
+            ),
+        ]
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastore().get_user_by_access_token(self.access_token)
         )
         token_id = user_tuple["token_id"]
 
-        self.get_success(
+        self.pusher = self.get_success(
             self.hs.get_pusherpool().add_pusher(
-                user_id=user_id,
+                user_id=self.user_id,
                 access_token=token_id,
                 kind="email",
                 app_id="m.email",
@@ -107,22 +116,54 @@ class EmailPusherTests(HomeserverTestCase):
             )
         )
 
-        # Create a room
-        room = self.helper.create_room_as(user_id, tok=access_token)
+    def test_simple_sends_email(self):
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
 
-        # Invite the other person
-        self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+        # The other user sends some messages
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[0].token)
 
-        # The other user joins
-        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+        # We should get emailed about that message
+        self._check_for_mail()
 
-        # The other user sends some messages
-        self.helper.send(room, body="Hi!", tok=other_access_token)
-        self.helper.send(room, body="There!", tok=other_access_token)
+    def test_multiple_members_email(self):
+        # We want to test multiple notifications, so we pause processing of push
+        # while we send messages.
+        self.pusher._pause_processing()
+
+        # Create a simple room with multiple other users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+        for other in self.others:
+            self.helper.invite(
+                room=room, src=self.user_id, tok=self.access_token, targ=other.id
+            )
+            self.helper.join(room=room, user=other.id, tok=other.token)
+
+        # The other users send some messages
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[1].token)
+        self.helper.send(room, body="There!", tok=self.others[1].token)
+
+        # Nothing should have happened yet, as we're paused.
+        assert not self.email_attempts
+
+        self.pusher._resume_processing()
+
+        # We should get emailed about those messages
+        self._check_for_mail()
+
+    def _check_for_mail(self):
+        "Check that the user receives an email notification"
 
         # Get the stream ordering before it gets sent
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+            self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -132,7 +173,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+            self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -149,7 +190,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         # The stream ordering has increased
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+            self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index aba618b2be..22c3f73ef3 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -23,15 +23,9 @@ from synapse.util.logcontext import make_deferred_yieldable
 
 from tests.unittest import HomeserverTestCase
 
-try:
-    from synapse.push.mailer import load_jinja2_templates
-except Exception:
-    load_jinja2_templates = None
-
 
 class HTTPPusherTests(HomeserverTestCase):
 
-    skip = "No Jinja installed" if not load_jinja2_templates else None
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index ee5f09041f..5877bb2133 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -30,7 +30,7 @@ from tests import unittest
 
 
 class VersionTestCase(unittest.HomeserverTestCase):
-    url = '/_synapse/admin/v1/server_version'
+    url = "/_synapse/admin/v1/server_version"
 
     def create_test_json_resource(self):
         resource = JsonResource(self.hs)
@@ -43,7 +43,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
-            {'server_version', 'python_version'}, set(channel.json_body.keys())
+            {"server_version", "python_version"}, set(channel.json_body.keys())
         )
 
 
@@ -68,7 +68,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         self.hs = self.setup_test_homeserver()
 
-        self.hs.config.registration_shared_secret = u"shared"
+        self.hs.config.registration_shared_secret = "shared"
 
         self.hs.get_media_repository = Mock()
         self.hs.get_deactivate_account_handler = Mock()
@@ -82,12 +82,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         self.hs.config.registration_shared_secret = None
 
-        request, channel = self.make_request("POST", self.url, b'{}')
+        request, channel = self.make_request("POST", self.url, b"{}")
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
-            'Shared secret registration is not enabled', channel.json_body["error"]
+            "Shared secret registration is not enabled", channel.json_body["error"]
         )
 
     def test_get_nonce(self):
@@ -118,20 +118,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(59)
 
         body = json.dumps({"nonce": nonce})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('username must be specified', channel.json_body["error"])
+        self.assertEqual("username must be specified", channel.json_body["error"])
 
         # 61 seconds
         self.reactor.advance(2)
 
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('unrecognised nonce', channel.json_body["error"])
+        self.assertEqual("unrecognised nonce", channel.json_body["error"])
 
     def test_register_incorrect_nonce(self):
         """
@@ -154,7 +154,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -171,7 +171,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
         want_mac.update(
-            nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support"
+            nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
         )
         want_mac = want_mac.hexdigest()
 
@@ -185,7 +185,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -200,7 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
-        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
         want_mac = want_mac.hexdigest()
 
         body = json.dumps(
@@ -212,18 +212,18 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('unrecognised nonce', channel.json_body["error"])
+        self.assertEqual("unrecognised nonce", channel.json_body["error"])
 
     def test_missing_parts(self):
         """
@@ -243,11 +243,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('nonce must be specified', channel.json_body["error"])
+        self.assertEqual("nonce must be specified", channel.json_body["error"])
 
         #
         # Username checks
@@ -255,35 +255,35 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce()})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('username must be specified', channel.json_body["error"])
+        self.assertEqual("username must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid username', channel.json_body["error"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
-        body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid username', channel.json_body["error"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid username', channel.json_body["error"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
 
         #
         # Password checks
@@ -291,37 +291,35 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce(), "username": "a"})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('password must be specified', channel.json_body["error"])
+        self.assertEqual("password must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid password', channel.json_body["error"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Must not have null bytes
-        body = json.dumps(
-            {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
-        )
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid password', channel.json_body["error"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Super long
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid password', channel.json_body["error"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
 
         #
         # user_type check
@@ -336,11 +334,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "user_type": "invalid",
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.render(request)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual('Invalid user type', channel.json_body["error"])
+        self.assertEqual("Invalid user type", channel.json_body["error"])
 
 
 class ShutdownRoomTestCase(unittest.HomeserverTestCase):
@@ -396,7 +394,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         url = "admin/shutdown_room/" + room_id
         request, channel = self.make_request(
             "POST",
-            url.encode('ascii'),
+            url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
             access_token=self.admin_user_tok,
         )
@@ -408,7 +406,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertEqual([], users_in_room)
 
-    @unittest.DEBUG
     def test_shutdown_room_block_peek(self):
         """Test that a world_readable room can no longer be peeked into after
         it has been shut down.
@@ -422,7 +419,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
         request, channel = self.make_request(
             "PUT",
-            url.encode('ascii'),
+            url.encode("ascii"),
             json.dumps({"history_visibility": "world_readable"}),
             access_token=self.other_user_token,
         )
@@ -433,7 +430,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         url = "admin/shutdown_room/" + room_id
         request, channel = self.make_request(
             "POST",
-            url.encode('ascii'),
+            url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
             access_token=self.admin_user_tok,
         )
@@ -450,7 +447,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         url = "rooms/%s/initialSync" % (room_id,)
         request, channel = self.make_request(
-            "GET", url.encode('ascii'), access_token=self.admin_user_tok
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.render(request)
         self.assertEqual(
@@ -459,7 +456,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         url = "events?timeout=0&room_id=" + room_id
         request, channel = self.make_request(
-            "GET", url.encode('ascii'), access_token=self.admin_user_tok
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.render(request)
         self.assertEqual(
@@ -487,7 +484,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         # Create a new group
         request, channel = self.make_request(
             "POST",
-            "/create_group".encode('ascii'),
+            "/create_group".encode("ascii"),
             access_token=self.admin_user_tok,
             content={"localpart": "test"},
         )
@@ -503,14 +500,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
         url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
         request, channel = self.make_request(
-            "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={}
+            "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
         )
         self.render(request)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         url = "/groups/%s/self/accept_invite" % (group_id,)
         request, channel = self.make_request(
-            "PUT", url.encode('ascii'), access_token=self.other_user_token, content={}
+            "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
         )
         self.render(request)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -523,7 +520,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         url = "/admin/delete_group/" + group_id
         request, channel = self.make_request(
             "POST",
-            url.encode('ascii'),
+            url.encode("ascii"),
             access_token=self.admin_user_tok,
             content={"localpart": "test"},
         )
@@ -545,7 +542,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
         url = "/groups/%s/profile" % (group_id,)
         request, channel = self.make_request(
-            "GET", url.encode('ascii'), access_token=self.admin_user_tok
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
 
         self.render(request)
@@ -557,7 +554,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         """Returns the list of groups the user is in (given their access token)
         """
         request, channel = self.make_request(
-            "GET", "/joined_groups".encode('ascii'), access_token=access_token
+            "GET", "/joined_groups".encode("ascii"), access_token=access_token
         )
 
         self.render(request)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 88f8f1abdc..6803b372ac 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -23,14 +23,8 @@ from synapse.rest.consent import consent_resource
 from tests import unittest
 from tests.server import render
 
-try:
-    from synapse.push.mailer import load_jinja2_templates
-except Exception:
-    load_jinja2_templates = None
-
 
 class ConsentResourceTestCase(unittest.HomeserverTestCase):
-    skip = "No Jinja installed" if not load_jinja2_templates else None
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -48,17 +42,17 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         # Make some temporary templates...
         temp_consent_path = self.mktemp()
         os.mkdir(temp_consent_path)
-        os.mkdir(os.path.join(temp_consent_path, 'en'))
+        os.mkdir(os.path.join(temp_consent_path, "en"))
 
         config["user_consent"] = {
             "version": "1",
             "template_dir": os.path.abspath(temp_consent_path),
         }
 
-        with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+        with open(os.path.join(temp_consent_path, "en/1.html"), "w") as f:
             f.write("{{version}},{{has_consented}}")
 
-        with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+        with open(os.path.join(temp_consent_path, "en/success.html"), "w") as f:
             f.write("yay!")
 
         hs = self.setup_test_homeserver(config=config)
@@ -94,7 +88,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # Get the version from the body, and whether we've consented
-        version, consented = channel.result["body"].decode('ascii').split(",")
+        version, consented = channel.result["body"].decode("ascii").split(",")
         self.assertEqual(consented, "False")
 
         # POST to the consent page, saying we've agreed
@@ -117,6 +111,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
 
         # Get the version from the body, and check that it's the version we
         # agreed to, and that we've consented to it.
-        version, consented = channel.result["body"].decode('ascii').split(",")
+        version, consented = channel.result["body"].decode("ascii").split(",")
         self.assertEqual(consented, "True")
         self.assertEqual(version, "1")
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 68949307d9..c973521907 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -56,7 +56,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
             "address": "test@example.com",
         }
         request_data = json.dumps(params)
-        request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii')
+        request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
         request, channel = self.make_request(
             b"POST", request_url, request_data, access_token=tok
         )
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
new file mode 100644
index 0000000000..7167fc56b6
--- /dev/null
+++ b/tests/rest/client/third_party_rules.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class ThirdPartyRulesTestModule(object):
+    def __init__(self, config):
+        pass
+
+    def check_event_allowed(self, event, context):
+        if event.type == "foo.bar.forbidden":
+            return False
+        else:
+            return True
+
+    @staticmethod
+    def parse_config(config):
+        return config
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["third_party_event_rules"] = {
+            "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
+            "config": {},
+        }
+
+        self.hs = self.setup_test_homeserver(config=config)
+        return self.hs
+
+    def test_third_party_rules(self):
+        """Tests that a forbidden event is forbidden from being sent, but an allowed one
+        can be sent.
+        """
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
+            {},
+            access_token=tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
+            {},
+            access_token=tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 769c37ce52..dff9b2f10c 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
+import json
+
 from mock import Mock
 
 from twisted.internet import defer
@@ -28,11 +30,14 @@ from tests import unittest
 from ....utils import MockHttpResource, setup_test_homeserver
 
 myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/api/v1"
+PATH_PREFIX = "/_matrix/client/r0"
+
 
+class MockHandlerProfileTestCase(unittest.TestCase):
+    """ Tests rest layer of profile management.
 
-class ProfileTestCase(unittest.TestCase):
-    """ Tests profile management. """
+    Todo: move these into ProfileTestCase
+    """
 
     @defer.inlineCallbacks
     def setUp(self):
@@ -159,6 +164,58 @@ class ProfileTestCase(unittest.TestCase):
         self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
 
 
+class ProfileTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        profile.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+        return self.hs
+
+    def prepare(self, reactor, clock, hs):
+        self.owner = self.register_user("owner", "pass")
+        self.owner_tok = self.login("owner", "pass")
+
+    def test_set_displayname(self):
+        request, channel = self.make_request(
+            "PUT",
+            "/profile/%s/displayname" % (self.owner,),
+            content=json.dumps({"displayname": "test"}),
+            access_token=self.owner_tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        res = self.get_displayname()
+        self.assertEqual(res, "test")
+
+    def test_set_displayname_too_long(self):
+        """Attempts to set a stupid displayname should get a 400"""
+        request, channel = self.make_request(
+            "PUT",
+            "/profile/%s/displayname" % (self.owner,),
+            content=json.dumps({"displayname": "test" * 100}),
+            access_token=self.owner_tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 400, channel.result)
+
+        res = self.get_displayname()
+        self.assertEqual(res, "owner")
+
+    def get_displayname(self):
+        request, channel = self.make_request(
+            "GET", "/profile/%s/displayname" % (self.owner,)
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+        return channel.json_body["displayname"]
+
+
 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
     servlets = [
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5f75ad7579..fe741637f5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -79,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
         # send a message in one of the rooms
         self.created_rmid_msg_path = (
             "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
-        ).encode('ascii')
+        ).encode("ascii")
         request, channel = self.make_request(
             "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
         )
@@ -89,7 +89,7 @@ class RoomPermissionsTestCase(RoomBase):
         # set topic for public room
         request, channel = self.make_request(
             "PUT",
-            ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
+            ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
             b'{"topic":"Public Room Topic"}',
         )
         self.render(request)
@@ -193,7 +193,7 @@ class RoomPermissionsTestCase(RoomBase):
         request, channel = self.make_request("GET", topic_path)
         self.render(request)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
-        self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)
+        self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
 
         # set/get topic in created PRIVATE room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
@@ -497,7 +497,7 @@ class RoomTopicTestCase(RoomBase):
 
     def test_invalid_puts(self):
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", self.path, '{}')
+        request, channel = self.make_request("PUT", self.path, "{}")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -515,11 +515,11 @@ class RoomTopicTestCase(RoomBase):
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, 'text only')
+        request, channel = self.make_request("PUT", self.path, "text only")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, '')
+        request, channel = self.make_request("PUT", self.path, "")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -572,7 +572,7 @@ class RoomMemberStateTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, '{}')
+        request, channel = self.make_request("PUT", path, "{}")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -590,11 +590,11 @@ class RoomMemberStateTestCase(RoomBase):
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, 'text only')
+        request, channel = self.make_request("PUT", path, "text only")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, '')
+        request, channel = self.make_request("PUT", path, "")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -604,7 +604,7 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.JOIN,
             Membership.LEAVE,
         )
-        request, channel = self.make_request("PUT", path, content.encode('ascii'))
+        request, channel = self.make_request("PUT", path, content.encode("ascii"))
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -616,7 +616,7 @@ class RoomMemberStateTestCase(RoomBase):
 
         # valid join message (NOOP since we made the room)
         content = '{"membership":"%s"}' % Membership.JOIN
-        request, channel = self.make_request("PUT", path, content.encode('ascii'))
+        request, channel = self.make_request("PUT", path, content.encode("ascii"))
         self.render(request)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
@@ -678,7 +678,7 @@ class RoomMessagesTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, b'{}')
+        request, channel = self.make_request("PUT", path, b"{}")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -696,11 +696,11 @@ class RoomMessagesTestCase(RoomBase):
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'text only')
+        request, channel = self.make_request("PUT", path, b"text only")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'')
+        request, channel = self.make_request("PUT", path, b"")
         self.render(request)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
@@ -786,7 +786,7 @@ class RoomMessageListTestCase(RoomBase):
         self.render(request)
         self.assertEquals(200, channel.code)
         self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body['start'])
+        self.assertEquals(token, channel.json_body["start"])
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
@@ -798,7 +798,7 @@ class RoomMessageListTestCase(RoomBase):
         self.render(request)
         self.assertEquals(200, channel.code)
         self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body['start'])
+        self.assertEquals(token, channel.json_body["start"])
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
@@ -920,7 +920,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         self.url = b"/_matrix/client/r0/publicRooms"
 
         config = self.default_config()
-        config["restrict_public_rooms_to_local_users"] = True
+        config["allow_public_rooms_without_auth"] = False
         self.hs = self.setup_test_homeserver(config=config)
 
         return self.hs
@@ -961,9 +961,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
 
         # Set a profile for the test user
         self.displayname = "test user"
-        data = {
-            "displayname": self.displayname,
-        }
+        data = {"displayname": self.displayname}
         request_data = json.dumps(data)
         request, channel = self.make_request(
             "PUT",
@@ -977,16 +975,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
 
     def test_per_room_profile_forbidden(self):
-        data = {
-            "membership": "join",
-            "displayname": "other test user"
-        }
+        data = {"membership": "join", "displayname": "other test user"}
         request_data = json.dumps(data)
         request, channel = self.make_request(
             "PUT",
-            "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
-                self.room_id, self.user_id,
-            ),
+            "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+            % (self.room_id, self.user_id),
             request_data,
             access_token=self.tok,
         )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 05b0143c42..9915367144 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -44,7 +44,7 @@ class RestHelper(object):
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
+            self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
         )
         render(request, self.resource, self.hs.get_reactor())
 
@@ -93,7 +93,7 @@ class RestHelper(object):
         data = {"membership": membership}
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+            self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
         )
 
         render(request, self.resource, self.hs.get_reactor())
@@ -117,7 +117,24 @@ class RestHelper(object):
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+            self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
+        )
+        render(request, self.resource, self.hs.get_reactor())
+
+        assert int(channel.result["code"]) == expect_code, (
+            "Expected: %d, got: %d, resp: %r"
+            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        )
+
+        return channel.json_body
+
+    def send_state(self, room_id, event_type, body, tok, expect_code=200):
+        path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+        if tok:
+            path = path + "?access_token=%s" % tok
+
+        request, channel = make_request(
+            self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8")
         )
         render(request, self.resource, self.hs.get_reactor())
 
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
new file mode 100644
index 0000000000..920de41de4
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -0,0 +1,281 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import re
+from email.parser import Parser
+
+import pkg_resources
+
+import synapse.rest.admin
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, register
+
+from tests import unittest
+
+
+class PasswordResetTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        account.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        register.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        # Email config.
+        self.email_attempts = []
+
+        def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+            self.email_attempts.append(msg)
+            return
+
+        config["email"] = {
+            "enable_notifs": False,
+            "template_dir": os.path.abspath(
+                pkg_resources.resource_filename("synapse", "res/templates")
+            ),
+            "smtp_host": "127.0.0.1",
+            "smtp_port": 20,
+            "require_transport_security": False,
+            "smtp_user": None,
+            "smtp_pass": None,
+            "notif_from": "test@example.com",
+        }
+        config["public_baseurl"] = "https://example.com"
+
+        hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+    def test_basic_password_reset(self):
+        """Test basic password reset flow
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        session_id = self._request_token(email, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link)
+
+        self._reset_password(new_password, session_id, client_secret)
+
+        # Assert we can log in with the new password
+        self.login("kermit", new_password)
+
+        # Assert we can't log in with the old password
+        self.attempt_wrong_password_login("kermit", old_password)
+
+    def test_cant_reset_password_without_clicking_link(self):
+        """Test that we do actually need to click the link in the email
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        session_id = self._request_token(email, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+
+        # Attempt to reset password without clicking the link
+        self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+        # Assert we can log in with the old password
+        self.login("kermit", old_password)
+
+        # Assert we can't log in with the new password
+        self.attempt_wrong_password_login("kermit", new_password)
+
+    def test_no_valid_token(self):
+        """Test that we do actually need to request a token and can't just
+        make a session up.
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        session_id = "weasle"
+
+        # Attempt to reset password without even requesting an email
+        self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+        # Assert we can log in with the old password
+        self.login("kermit", old_password)
+
+        # Assert we can't log in with the new password
+        self.attempt_wrong_password_login("kermit", new_password)
+
+    def _request_token(self, email, client_secret):
+        request, channel = self.make_request(
+            "POST",
+            b"account/password/email/requestToken",
+            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+
+        return channel.json_body["sid"]
+
+    def _validate_token(self, link):
+        # Remove the host
+        path = link.replace("https://example.com", "")
+
+        request, channel = self.make_request("GET", path, shorthand=False)
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+
+    def _get_link_from_email(self):
+        assert self.email_attempts, "No emails have been sent"
+
+        raw_msg = self.email_attempts[-1].decode("UTF-8")
+        mail = Parser().parsestr(raw_msg)
+
+        text = None
+        for part in mail.walk():
+            if part.get_content_type() == "text/plain":
+                text = part.get_payload(decode=True).decode("UTF-8")
+                break
+
+        if not text:
+            self.fail("Could not find text portion of email to parse")
+
+        match = re.search(r"https://example.com\S+", text)
+        assert match, "Could not find link in email"
+
+        return match.group(0)
+
+    def _reset_password(
+        self, new_password, session_id, client_secret, expected_code=200
+    ):
+        request, channel = self.make_request(
+            "POST",
+            b"account/password",
+            {
+                "new_password": new_password,
+                "auth": {
+                    "type": LoginType.EMAIL_IDENTITY,
+                    "threepid_creds": {
+                        "client_secret": client_secret,
+                        "sid": session_id,
+                    },
+                },
+            },
+        )
+        self.render(request)
+        self.assertEquals(expected_code, channel.code, channel.result)
+
+
+class DeactivateTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        account.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        return hs
+
+    def test_deactivate_account(self):
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "user": user_id,
+                    "password": "test",
+                },
+                "erase": False,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        store = self.hs.get_datastore()
+
+        # Check that the user has been marked as deactivated.
+        self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
+
+        # Check that this access token has been invalidated.
+        request, channel = self.make_request("GET", "account/whoami")
+        self.render(request)
+        self.assertEqual(request.code, 401)
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index f3ef977404..b9e01c9418 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import synapse.rest.admin
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import capabilities
 
@@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.url = b"/_matrix/client/r0/capabilities"
         hs = self.setup_test_homeserver()
         self.store = hs.get_datastore()
+        self.config = hs.config
         return hs
 
     def test_check_auth_required(self):
@@ -46,13 +47,15 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
         request, channel = self.make_request("GET", self.url, access_token=access_token)
         self.render(request)
-        capabilities = channel.json_body['capabilities']
+        capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
-        for room_version in capabilities['m.room_versions']['available'].keys():
+        for room_version in capabilities["m.room_versions"]["available"].keys():
             self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+
         self.assertEqual(
-            DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
+            self.config.default_room_version.identifier,
+            capabilities["m.room_versions"]["default"],
         )
 
     def test_get_change_password_capabilities(self):
@@ -63,16 +66,16 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
         request, channel = self.make_request("GET", self.url, access_token=access_token)
         self.render(request)
-        capabilities = channel.json_body['capabilities']
+        capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
 
         # Test case where password is handled outside of Synapse
-        self.assertTrue(capabilities['m.change_password']['enabled'])
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
         self.get_success(self.store.user_set_password_hash(user, None))
         request, channel = self.make_request("GET", self.url, access_token=access_token)
         self.render(request)
-        capabilities = channel.json_body['capabilities']
+        capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
-        self.assertFalse(capabilities['m.change_password']['enabled'])
+        self.assertFalse(capabilities["m.change_password"]["enabled"])
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d4a1d4d50c..89a3f95c0a 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -26,15 +26,10 @@ from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import account_validity, register, sync
+from synapse.rest.client.v2_alpha import account, account_validity, register, sync
 
 from tests import unittest
 
-try:
-    from synapse.push.mailer import load_jinja2_templates
-except ImportError:
-    load_jinja2_templates = None
-
 
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -307,13 +302,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
-    skip = "No Jinja installed" if not load_jinja2_templates else None
     servlets = [
         register.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
         sync.register_servlets,
         account_validity.register_servlets,
+        account.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
@@ -340,7 +335,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         config["email"] = {
             "enable_notifs": True,
             "template_dir": os.path.abspath(
-                pkg_resources.resource_filename('synapse', 'res/templates')
+                pkg_resources.resource_filename("synapse", "res/templates")
             ),
             "expiry_template_html": "notice_expiry.html",
             "expiry_template_text": "notice_expiry.txt",
@@ -364,20 +359,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
     def test_renewal_email(self):
         self.email_attempts = []
 
-        user_id = self.register_user("kermit", "monkey")
-        tok = self.login("kermit", "monkey")
-        # We need to manually add an email address otherwise the handler will do
-        # nothing.
-        now = self.hs.clock.time_msec()
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address="kermit@example.com",
-                validated_at=now,
-                added_at=now,
-            )
-        )
+        (user_id, tok) = self.create_user()
 
         # Move 6 days forward. This should trigger a renewal email to be sent.
         self.reactor.advance(datetime.timedelta(days=6).total_seconds())
@@ -402,6 +384,43 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
     def test_manual_email_send(self):
         self.email_attempts = []
 
+        (user_id, tok) = self.create_user()
+        request, channel = self.make_request(
+            b"POST",
+            "/_matrix/client/unstable/account_validity/send_mail",
+            access_token=tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.assertEqual(len(self.email_attempts), 1)
+
+    def test_deactivated_user(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "user": user_id,
+                    "password": "monkey",
+                },
+                "erase": False,
+            }
+        )
+        request, channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        self.assertEqual(len(self.email_attempts), 0)
+
+    def create_user(self):
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
         # We need to manually add an email address otherwise the handler will do
@@ -416,7 +435,33 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
                 added_at=now,
             )
         )
+        return (user_id, tok)
+
+    def test_manual_email_send_expired_account(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
 
+        # We need to manually add an email address otherwise the handler will do
+        # nothing.
+        now = self.hs.clock.time_msec()
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address="kermit@example.com",
+                validated_at=now,
+                added_at=now,
+            )
+        )
+
+        # Make the account expire.
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        # Ignore all emails sent by the automatic background task and only focus on the
+        # ones sent manually.
+        self.email_attempts = []
+
+        # Test that we're still able to manually trigger a mail to be sent.
         request, channel = self.make_request(
             b"POST",
             "/_matrix/client/unstable/account_validity/send_mail",
@@ -430,19 +475,16 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
 class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-    ]
+    servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
 
     def make_homeserver(self, reactor, clock):
         self.validity_period = 10
+        self.max_delta = self.validity_period * 10.0 / 100.0
 
         config = self.default_config()
 
         config["enable_registration"] = True
-        config["account_validity"] = {
-            "enabled": False,
-        }
+        config["account_validity"] = {"enabled": False}
 
         self.hs = self.setup_test_homeserver(config=config)
         self.hs.config.account_validity.period = self.validity_period
@@ -453,14 +495,18 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
     def test_background_job(self):
         """
-        Tests whether the account validity startup background job does the right thing,
-        which is sticking an expiration date to every account that doesn't already have
-        one.
+        Tests the same thing as test_background_job, except that it sets the
+        startup_job_max_delta parameter and checks that the expiration date is within the
+        allowed range.
         """
-        user_id = self.register_user("kermit", "user")
+        user_id = self.register_user("kermit_delta", "user")
+
+        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
 
         now_ms = self.hs.clock.time_msec()
         self.get_success(self.store._set_expiration_date_when_missing())
 
         res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
-        self.assertEqual(res, now_ms + self.validity_period)
+
+        self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
+        self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 43b3049daa..3deeed3a70 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -56,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         creates the right shape of event.
         """
 
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
         self.assertEquals(200, channel.code, channel.json_body)
 
         event_id = channel.json_body["event_id"]
@@ -76,7 +76,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 "content": {
                     "m.relates_to": {
                         "event_id": self.parent_id,
-                        "key": u"👍",
+                        "key": "👍",
                         "rel_type": RelationTypes.ANNOTATION,
                     }
                 },
@@ -187,7 +187,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_tokens.append(token)
 
         idx = 0
-        sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
+        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
         for key in itertools.chain.from_iterable(
             itertools.repeat(key, num) for key, num in sent_groups.items()
         ):
@@ -259,7 +259,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel = self._send_relation(
                 RelationTypes.ANNOTATION,
                 "m.reaction",
-                key=u"👍",
+                key="👍",
                 access_token=access_tokens[idx],
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         prev_token = None
         found_event_ids = []
-        encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
+        encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
         for _ in range(20):
             from_token = ""
             if prev_token:
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index 00688a7325..ebd7869208 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -21,17 +21,17 @@ from tests import unittest
 class GetFileNameFromHeadersTests(unittest.TestCase):
     # input -> expected result
     TEST_CASES = {
-        b"inline; filename=abc.txt": u"abc.txt",
-        b'inline; filename="azerty"': u"azerty",
-        b'inline; filename="aze%20rty"': u"aze%20rty",
-        b'inline; filename="aze\"rty"': u'aze"rty',
-        b'inline; filename="azer;ty"': u"azer;ty",
-        b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
+        b"inline; filename=abc.txt": "abc.txt",
+        b'inline; filename="azerty"': "azerty",
+        b'inline; filename="aze%20rty"': "aze%20rty",
+        b'inline; filename="aze"rty"': 'aze"rty',
+        b'inline; filename="azer;ty"': "azer;ty",
+        b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
     }
 
     def tests(self):
         for hdr, expected in self.TEST_CASES.items():
-            res = get_filename_from_headers({b'Content-Disposition': [hdr]})
+            res = get_filename_from_headers({b"Content-Disposition": [hdr]})
             self.assertEqual(
                 res,
                 expected,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1069a44145..e2d418b1df 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -143,7 +143,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
 
         self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b'download']
+        self.download_resource = self.media_repo.children[b"download"]
 
         # smol png
         self.end_content = unhexlify(
@@ -171,7 +171,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         headers = {
             b"Content-Length": [b"%d" % (len(self.end_content))],
-            b"Content-Type": [b'image/png'],
+            b"Content-Type": [b"image/png"],
         }
         if content_disposition:
             headers[b"Content-Disposition"] = [content_disposition]
@@ -204,7 +204,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         correctly decode it as the UTF-8 string, and use filename* in the
         response.
         """
-        filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+        filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
         channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
 
         headers = channel.headers
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 1ab0f7293a..8fe5961866 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -55,10 +55,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     hijack_auth = True
     user_id = "@test:user"
     end_content = (
-        b'<html><head>'
+        b"<html><head>"
         b'<meta property="og:title" content="~matrix~" />'
         b'<meta property="og:description" content="hi" />'
-        b'</head></html>'
+        b"</head></html>"
     )
 
     def make_homeserver(self, reactor, clock):
@@ -98,7 +98,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
 
         self.media_repo = hs.get_media_repository_resource()
-        self.preview_url = self.media_repo.children[b'preview_url']
+        self.preview_url = self.media_repo.children[b"preview_url"]
 
         self.lookups = {}
 
@@ -109,7 +109,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 hostName,
                 portNumber=0,
                 addressTypes=None,
-                transportSemantics='TCP',
+                transportSemantics="TCP",
             ):
 
                 resolution = HostResolution(hostName)
@@ -118,7 +118,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                     raise DNSLookupError("OH NO")
 
                 for i in self.lookups[hostName]:
-                    resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
+                    resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber))
                 resolutionReceiver.resolutionComplete()
                 return resolutionReceiver
 
@@ -184,11 +184,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
 
         end_content = (
-            b'<html><head>'
+            b"<html><head>"
             b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
             b'<meta property="og:title" content="\xe4\xea\xe0" />'
             b'<meta property="og:description" content="hi" />'
-            b'</head></html>'
+            b"</head></html>"
         )
 
         request, channel = self.make_request(
@@ -204,7 +204,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         client.dataReceived(
             (
                 b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
-                b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
+                b'Content-Type: text/html; charset="utf8"\r\n\r\n'
             )
             % (len(end_content),)
             + end_content
@@ -212,16 +212,16 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.pump()
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+        self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_non_ascii_preview_content_type(self):
         self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
 
         end_content = (
-            b'<html><head>'
+            b"<html><head>"
             b'<meta property="og:title" content="\xe4\xea\xe0" />'
             b'<meta property="og:description" content="hi" />'
-            b'</head></html>'
+            b"</head></html>"
         )
 
         request, channel = self.make_request(
@@ -237,7 +237,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         client.dataReceived(
             (
                 b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
-                b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
+                b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
             )
             % (len(end_content),)
             + end_content
@@ -245,7 +245,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.pump()
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+        self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_ipaddr(self):
         """
@@ -293,8 +293,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'DNS resolution failure during URL preview generation',
+                "errcode": "M_UNKNOWN",
+                "error": "DNS resolution failure during URL preview generation",
             },
         )
 
@@ -314,8 +314,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'DNS resolution failure during URL preview generation',
+                "errcode": "M_UNKNOWN",
+                "error": "DNS resolution failure during URL preview generation",
             },
         )
 
@@ -334,8 +334,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'IP address blocked by IP blacklist entry',
+                "errcode": "M_UNKNOWN",
+                "error": "IP address blocked by IP blacklist entry",
             },
         )
         self.assertEqual(channel.code, 403)
@@ -354,8 +354,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'IP address blocked by IP blacklist entry',
+                "errcode": "M_UNKNOWN",
+                "error": "IP address blocked by IP blacklist entry",
             },
         )
 
@@ -396,7 +396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         non-blacklisted one, it will be rejected.
         """
         # Hardcode the URL resolving to the IP we want.
-        self.lookups[u"example.com"] = [
+        self.lookups["example.com"] = [
             (IPv4Address, "1.1.1.2"),
             (IPv4Address, "8.8.8.8"),
         ]
@@ -410,8 +410,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'DNS resolution failure during URL preview generation',
+                "errcode": "M_UNKNOWN",
+                "error": "DNS resolution failure during URL preview generation",
             },
         )
 
@@ -435,8 +435,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'DNS resolution failure during URL preview generation',
+                "errcode": "M_UNKNOWN",
+                "error": "DNS resolution failure during URL preview generation",
             },
         )
 
@@ -456,7 +456,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body,
             {
-                'errcode': 'M_UNKNOWN',
-                'error': 'DNS resolution failure during URL preview generation',
+                "errcode": "M_UNKNOWN",
+                "error": "DNS resolution failure during URL preview generation",
             },
         )
diff --git a/tests/server.py b/tests/server.py
index c15a47f2a4..e573c4e4c5 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel(object):
     def json_body(self):
         if not self.result:
             raise Exception("No result yet.")
-        return json.loads(self.result["body"].decode('utf8'))
+        return json.loads(self.result["body"].decode("utf8"))
 
     @property
     def code(self):
@@ -151,10 +151,10 @@ def make_request(
         Tuple[synapse.http.site.SynapseRequest, channel]
     """
     if not isinstance(method, bytes):
-        method = method.encode('ascii')
+        method = method.encode("ascii")
 
     if not isinstance(path, bytes):
-        path = path.encode('ascii')
+        path = path.encode("ascii")
 
     # Decorate it to be the full path, if we're using shorthand
     if shorthand and not path.startswith(b"/_matrix"):
@@ -165,7 +165,7 @@ def make_request(
         path = b"/" + path
 
     if isinstance(content, text_type):
-        content = content.encode('utf8')
+        content = content.encode("utf8")
 
     site = FakeSite()
     channel = FakeChannel(reactor)
@@ -173,11 +173,11 @@ def make_request(
     req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
-    req.postpath = list(map(unquote, path[1:].split(b'/')))
+    req.postpath = list(map(unquote, path[1:].split(b"/")))
 
     if access_token:
         req.requestHeaders.addRawHeader(
-            b"Authorization", b"Bearer " + access_token.encode('ascii')
+            b"Authorization", b"Bearer " + access_token.encode("ascii")
         )
 
     if federation_auth_origin is not None:
@@ -242,7 +242,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
         super(ThreadedMemoryReactorClock, self).__init__()
 
-    def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
         self._udp.append(p)
@@ -371,7 +371,7 @@ class FakeTransport(object):
 
     disconnecting = False
     disconnected = False
-    buffer = attr.ib(default=b'')
+    buffer = attr.ib(default=b"")
     producer = attr.ib(default=None)
     autoflush = attr.ib(default=True)
 
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 739ee59ce4..984feb623f 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -109,7 +109,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, 'foo')
+            side_effect=ResourceLimitError(403, "foo")
         )
 
         mock_event = Mock(
@@ -128,7 +128,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
 
         self._rlsn._auth.check_auth_blocking = Mock(
-            side_effect=ResourceLimitError(403, 'foo')
+            side_effect=ResourceLimitError(403, "foo")
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 9c5311d916..8d3845c870 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -181,7 +181,7 @@ class StateTestCase(unittest.TestCase):
                 id="PB",
                 sender=BOB,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
         ]
@@ -229,14 +229,14 @@ class StateTestCase(unittest.TestCase):
                 id="PB",
                 sender=BOB,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}},
             ),
             FakeEvent(
                 id="PC",
                 sender=CHARLIE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}},
             ),
         ]
@@ -256,7 +256,7 @@ class StateTestCase(unittest.TestCase):
                 id="PA1",
                 sender=ALICE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
             FakeEvent(
@@ -266,14 +266,14 @@ class StateTestCase(unittest.TestCase):
                 id="PA2",
                 sender=ALICE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 0}},
             ),
             FakeEvent(
                 id="PB",
                 sender=BOB,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
             FakeEvent(
@@ -296,7 +296,7 @@ class StateTestCase(unittest.TestCase):
                 id="PA",
                 sender=ALICE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
             FakeEvent(
@@ -326,7 +326,7 @@ class StateTestCase(unittest.TestCase):
                 id="PA1",
                 sender=ALICE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
             FakeEvent(
@@ -336,14 +336,14 @@ class StateTestCase(unittest.TestCase):
                 id="PA2",
                 sender=ALICE,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 0}},
             ),
             FakeEvent(
                 id="PB",
                 sender=BOB,
                 type=EventTypes.PowerLevels,
-                state_key='',
+                state_key="",
                 content={"users": {ALICE: 100, BOB: 50}},
             ),
             FakeEvent(
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 25a6c89ef5..622b16a071 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -74,7 +74,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             namespaces={},
         )
         # use the token as the filename
-        with open(as_token, 'w') as outfile:
+        with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
             self.as_yaml_files.append(as_token)
 
@@ -135,7 +135,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             namespaces={},
         )
         # use the token as the filename
-        with open(as_token, 'w') as outfile:
+        with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
             self.as_yaml_files.append(as_token)
 
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 6dda66ecd3..e9e2d5337c 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -15,7 +15,6 @@
 
 import os.path
 
-from synapse.api.constants import EventTypes
 from synapse.storage import prepare_database
 from synapse.types import Requester, UserID
 
@@ -23,12 +22,12 @@ from tests.unittest import HomeserverTestCase
 
 
 class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
-    """Test the background update to clean forward extremities table.
+    """
+    Test the background update to clean forward extremities table.
     """
 
     def prepare(self, reactor, clock, homeserver):
         self.store = homeserver.get_datastore()
-        self.event_creator = homeserver.get_event_creation_handler()
         self.room_creator = homeserver.get_room_creation_handler()
 
         # Create a test user and room
@@ -37,56 +36,6 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         info = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
 
-    def create_and_send_event(self, soft_failed=False, prev_event_ids=None):
-        """Create and send an event.
-
-        Args:
-            soft_failed (bool): Whether to create a soft failed event or not
-            prev_event_ids (list[str]|None): Explicitly set the prev events,
-                or if None just use the default
-
-        Returns:
-            str: The new event's ID.
-        """
-        prev_events_and_hashes = None
-        if prev_event_ids:
-            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
-        event, context = self.get_success(
-            self.event_creator.create_event(
-                self.requester,
-                {
-                    "type": EventTypes.Message,
-                    "room_id": self.room_id,
-                    "sender": self.user.to_string(),
-                    "content": {"body": "", "msgtype": "m.text"},
-                },
-                prev_events_and_hashes=prev_events_and_hashes,
-            )
-        )
-
-        if soft_failed:
-            event.internal_metadata.soft_failed = True
-
-        self.get_success(
-            self.event_creator.send_nonmember_event(self.requester, event, context)
-        )
-
-        return event.event_id
-
-    def add_extremity(self, event_id):
-        """Add the given event as an extremity to the room.
-        """
-        self.get_success(
-            self.store._simple_insert(
-                table="event_forward_extremities",
-                values={"room_id": self.room_id, "event_id": event_id},
-                desc="test_add_extremity",
-            )
-        )
-
-        self.store.get_latest_event_ids_in_room.invalidate((self.room_id,))
-
     def run_background_update(self):
         """Re run the background update to clean up the extremities.
         """
@@ -126,10 +75,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """
 
         # Create the room graph
-        event_id_1 = self.create_and_send_event()
-        event_id_2 = self.create_and_send_event(True, [event_id_1])
-        event_id_3 = self.create_and_send_event(True, [event_id_2])
-        event_id_4 = self.create_and_send_event(False, [event_id_3])
+        event_id_1 = self.create_and_send_event(self.room_id, self.user)
+        event_id_2 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_1]
+        )
+        event_id_3 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_2]
+        )
+        event_id_4 = self.create_and_send_event(
+            self.room_id, self.user, False, [event_id_3]
+        )
 
         # Check the latest events are as expected
         latest_event_ids = self.get_success(
@@ -149,12 +104,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         Where SF* are soft failed, and with extremities of A and B
         """
         # Create the room graph
-        event_id_a = self.create_and_send_event()
-        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
-        event_id_b = self.create_and_send_event(False, [event_id_sf1])
+        event_id_a = self.create_and_send_event(self.room_id, self.user)
+        event_id_sf1 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_a]
+        )
+        event_id_b = self.create_and_send_event(
+            self.room_id, self.user, False, [event_id_sf1]
+        )
 
         # Add the new extremity and check the latest events are as expected
-        self.add_extremity(event_id_a)
+        self.add_extremity(self.room_id, event_id_a)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
@@ -180,13 +139,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         Where SF* are soft failed, and with extremities of A and B
         """
         # Create the room graph
-        event_id_a = self.create_and_send_event()
-        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
-        event_id_sf2 = self.create_and_send_event(True, [event_id_sf1])
-        event_id_b = self.create_and_send_event(False, [event_id_sf2])
+        event_id_a = self.create_and_send_event(self.room_id, self.user)
+        event_id_sf1 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_a]
+        )
+        event_id_sf2 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_sf1]
+        )
+        event_id_b = self.create_and_send_event(
+            self.room_id, self.user, False, [event_id_sf2]
+        )
 
         # Add the new extremity and check the latest events are as expected
-        self.add_extremity(event_id_a)
+        self.add_extremity(self.room_id, event_id_a)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
@@ -220,17 +185,28 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         Where SF* are soft failed, and with them A, B and C marked as
         extremities. This should resolve to B and C being marked as extremity.
         """
+
         # Create the room graph
-        event_id_a = self.create_and_send_event()
-        event_id_b = self.create_and_send_event()
-        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
-        event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b])
-        event_id_sf3 = self.create_and_send_event(True, [event_id_sf1])
-        self.create_and_send_event(True, [event_id_sf2, event_id_sf3])  # SF4
-        event_id_c = self.create_and_send_event(False, [event_id_sf3])
+        event_id_a = self.create_and_send_event(self.room_id, self.user)
+        event_id_b = self.create_and_send_event(self.room_id, self.user)
+        event_id_sf1 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_a]
+        )
+        event_id_sf2 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_a, event_id_b]
+        )
+        event_id_sf3 = self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_sf1]
+        )
+        self.create_and_send_event(
+            self.room_id, self.user, True, [event_id_sf2, event_id_sf3]
+        )  # SF4
+        event_id_c = self.create_and_send_event(
+            self.room_id, self.user, False, [event_id_sf3]
+        )
 
         # Add the new extremity and check the latest events are as expected
-        self.add_extremity(event_id_a)
+        self.add_extremity(self.room_id, event_id_a)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
@@ -246,3 +222,44 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
         self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
+
+
+class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["cleanup_extremities_with_dummy_events"] = True
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor, clock, homeserver):
+        self.store = homeserver.get_datastore()
+        self.room_creator = homeserver.get_room_creation_handler()
+
+        # Create a test user and room
+        self.user = UserID("alice", "test")
+        self.requester = Requester(self.user, None, False, None, None)
+        info = self.get_success(self.room_creator.create_room(self.requester, {}))
+        self.room_id = info["room_id"]
+
+    def test_send_dummy_event(self):
+        # Create a bushy graph with 50 extremities.
+
+        event_id_start = self.create_and_send_event(self.room_id, self.user)
+
+        for _ in range(50):
+            self.create_and_send_event(
+                self.room_id, self.user, prev_event_ids=[event_id_start]
+            )
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(len(latest_event_ids), 50)
+
+        # Pump the reactor repeatedly so that the background updates have a
+        # chance to run.
+        self.pump(10 * 60)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index b62eae7abc..59c6f8c227 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -94,11 +94,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             result,
             [
                 {
-                    'access_token': 'access_token',
-                    'ip': 'ip',
-                    'user_agent': 'user_agent',
-                    'device_id': None,
-                    'last_seen': 12345678000,
+                    "access_token": "access_token",
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "device_id": None,
+                    "last_seen": 12345678000,
                 }
             ],
         )
@@ -125,11 +125,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             result,
             [
                 {
-                    'access_token': 'access_token',
-                    'ip': 'ip',
-                    'user_agent': 'user_agent',
-                    'device_id': None,
-                    'last_seen': 12345878000,
+                    "access_token": "access_token",
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "device_id": None,
+                    "last_seen": 12345878000,
                 }
             ],
         )
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index aef4dfaf57..3cc18f9f1c 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,6 +72,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_get_devices_by_remote(self):
+        device_ids = ["device_id1", "device_id2"]
+
+        # Add two device updates with a single stream_id
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids, ["somehost"]
+        )
+
+        # Get all device updates ever meant for this remote
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "somehost", -1, limit=100
+        )
+
+        # Check original device_ids are contained within these updates
+        self._check_devices_in_updates(device_ids, device_updates)
+
+    @defer.inlineCallbacks
+    def test_get_devices_by_remote_limited(self):
+        # Test breaking the update limit in 1, 101, and 1 device_id segments
+
+        # first add one device
+        device_ids1 = ["device_id0"]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids1, ["someotherhost"]
+        )
+
+        # then add 101
+        device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids2, ["someotherhost"]
+        )
+
+        # then one more
+        device_ids3 = ["newdevice"]
+        yield self.store.add_device_change_to_streams(
+            "user_id", device_ids3, ["someotherhost"]
+        )
+
+        #
+        # now read them back.
+        #
+
+        # first we should get a single update
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", -1, limit=100
+        )
+        self._check_devices_in_updates(device_ids1, device_updates)
+
+        # Then we should get an empty list back as the 101 devices broke the limit
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", now_stream_id, limit=100
+        )
+        self.assertEqual(len(device_updates), 0)
+
+        # The 101 devices should've been cleared, so we should now just get one device
+        # update
+        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+            "someotherhost", now_stream_id, limit=100
+        )
+        self._check_devices_in_updates(device_ids3, device_updates)
+
+    def _check_devices_in_updates(self, expected_device_ids, device_updates):
+        """Check that an specific device ids exist in a list of device update EDUs"""
+        self.assertEqual(len(device_updates), len(expected_device_ids))
+
+        received_device_ids = {update["device_id"] for update in device_updates}
+        self.assertEqual(received_device_ids, set(expected_device_ids))
+
+    @defer.inlineCallbacks
     def test_update_device(self):
         yield self.store.store_device("user_id", "device_id", "display_name 1")
 
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index cd2bcd4ca3..c8ece15284 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.store_device("user2", "device1", None)
         yield self.store.store_device("user2", "device2", None)
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11')
-        yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12')
-        yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21')
-        yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22')
+        yield self.store.set_e2e_device_keys("user1", "device1", now, "json11")
+        yield self.store.set_e2e_device_keys("user1", "device2", now, "json12")
+        yield self.store.set_e2e_device_keys("user2", "device1", now, "json21")
+        yield self.store.set_e2e_device_keys("user2", "device2", now, "json22")
 
         res = yield self.store.get_e2e_device_keys(
             (("user1", "device1"), ("user2", "device2"))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0d4e74d637..86c7ac350d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -27,11 +27,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_prev_events_for_room(self):
-        room_id = '@ROOM:local'
+        room_id = "@ROOM:local"
 
         # add a bunch of events and hashes to act as forward extremities
         def insert_event(txn, i):
-            event_id = '$event_%i:local' % i
+            event_id = "$event_%i:local" % i
 
             txn.execute(
                 (
@@ -45,19 +45,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
 
             txn.execute(
                 (
-                    'INSERT INTO event_forward_extremities (room_id, event_id) '
-                    'VALUES (?, ?)'
+                    "INSERT INTO event_forward_extremities (room_id, event_id) "
+                    "VALUES (?, ?)"
                 ),
                 (room_id, event_id),
             )
 
             txn.execute(
                 (
-                    'INSERT INTO event_reference_hashes '
-                    '(event_id, algorithm, hash) '
+                    "INSERT INTO event_reference_hashes "
+                    "(event_id, algorithm, hash) "
                     "VALUES (?, 'sha256', ?)"
                 ),
-                (event_id, b'ffff'),
+                (event_id, b"ffff"),
             )
 
         for i in range(0, 11):
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
new file mode 100644
index 0000000000..d44359ff93
--- /dev/null
+++ b/tests/storage/test_event_metrics.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from prometheus_client.exposition import generate_latest
+
+from synapse.metrics import REGISTRY
+from synapse.types import Requester, UserID
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremStatisticsTestCase(HomeserverTestCase):
+    def test_exposed_to_prometheus(self):
+        """
+        Forward extremity counts are exposed via Prometheus.
+        """
+        room_creator = self.hs.get_room_creation_handler()
+
+        user = UserID("alice", "test")
+        requester = Requester(user, None, False, None, None)
+
+        # Real events, forward extremities
+        events = [(3, 2), (6, 2), (4, 6)]
+
+        for event_count, extrems in events:
+            info = self.get_success(room_creator.create_room(requester, {}))
+            room_id = info["room_id"]
+
+            last_event = None
+
+            # Make a real event chain
+            for i in range(event_count):
+                ev = self.create_and_send_event(room_id, user, False, last_event)
+                last_event = [ev]
+
+            # Sprinkle in some extremities
+            for i in range(extrems):
+                ev = self.create_and_send_event(room_id, user, False, last_event)
+
+        # Let it run for a while, then pull out the statistics from the
+        # Prometheus client registry
+        self.reactor.advance(60 * 60 * 1000)
+        self.pump(1)
+
+        items = set(
+            filter(
+                lambda x: b"synapse_forward_extremities_" in x,
+                generate_latest(REGISTRY).split(b"\n"),
+            )
+        )
+
+        expected = set(
+            [
+                b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
+                b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
+                b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
+                b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
+                b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
+                b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
+                b"synapse_forward_extremities_count 3.0",
+                b"synapse_forward_extremities_sum 10.0",
+            ]
+        )
+
+        self.assertEqual(items, expected)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 6bfaa00fe9..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
 
 from twisted.internet.defer import Deferred
 
+from synapse.storage.keys import FetchKeyResult
+
 import tests.unittest
 
 KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
     def test_get_server_verify_keys(self):
         store = self.hs.get_datastore()
 
-        d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
-        self.get_success(d)
-        d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+        key_id_1 = "ed25519:key1"
+        key_id_2 = "ed25519:KEY_ID_2"
+        d = store.store_server_verify_keys(
+            "from_server",
+            10,
+            [
+                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+            ],
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys(
-            [
-                ("server1", "ed25519:key1"),
-                ("server1", "ed25519:key2"),
-                ("server1", "ed25519:key3"),
-            ]
+            [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
         )
         res = self.get_success(d)
 
         self.assertEqual(len(res.keys()), 3)
-        self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
-        self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+        res1 = res[("server1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.verify_key.version, "key1")
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("server1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        # version comes from the ID it was stored with
+        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # non-existent result gives None
         self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:key2"
 
-        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
-        self.get_success(d)
-        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+        d = store.store_server_verify_keys(
+            "from_server",
+            0,
+            [
+                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+            ],
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
         res = store.get_server_verify_keys([("srv1", key_id_1)])
         if isinstance(res, Deferred):
             res = self.successResultOf(res)
         self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
-        d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+        d = store.store_server_verify_keys(
+            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, new_key_2)
+        self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index f458c03054..0ce0b991f9 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -46,9 +46,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         user3_email = "user3@matrix.org"
 
         threepids = [
-            {'medium': 'email', 'address': user1_email},
-            {'medium': 'email', 'address': user2_email},
-            {'medium': 'email', 'address': user3_email},
+            {"medium": "email", "address": user1_email},
+            {"medium": "email", "address": user2_email},
+            {"medium": "email", "address": user3_email},
         ]
         # -1 because user3 is a support user and does not count
         user_num = len(threepids) - 1
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
         )
-        self.store.populate_monthly_active_users('user_id')
+        self.store.populate_monthly_active_users("user_id")
         self.pump()
         self.store.upsert_monthly_active_user.assert_called_once()
 
@@ -188,7 +188,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
-        self.store.populate_monthly_active_users('user_id')
+        self.store.populate_monthly_active_users("user_id")
         self.pump()
         self.store.upsert_monthly_active_user.assert_not_called()
 
@@ -198,13 +198,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.assertEquals(self.get_success(count), 0)
         # Test reserved users but no registered users
 
-        user1 = '@user1:example.com'
-        user2 = '@user2:example.com'
-        user1_email = 'user1@example.com'
-        user2_email = 'user2@example.com'
+        user1 = "@user1:example.com"
+        user2 = "@user2:example.com"
+        user1_email = "user1@example.com"
+        user2_email = "user2@example.com"
         threepids = [
-            {'medium': 'email', 'address': user1_email},
-            {'medium': 'email', 'address': user2_email},
+            {"medium": "email", "address": user1_email},
+            {"medium": "email", "address": user2_email},
         ]
         self.hs.config.mau_limits_reserved_threepids = threepids
         self.store.runInteraction(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4823d44dec..732a778fab 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -82,7 +82,7 @@ class RedactionTestCase(unittest.TestCase):
                 "sender": user.to_string(),
                 "state_key": user.to_string(),
                 "room_id": room.to_string(),
-                "content": {"body": body, "msgtype": u"message"},
+                "content": {"body": body, "msgtype": "message"},
             },
         )
 
@@ -118,7 +118,7 @@ class RedactionTestCase(unittest.TestCase):
     def test_redact(self):
         yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
+        msg_event = yield self.inject_message(self.room1, self.u_alice, "t")
 
         # Check event has not been redacted:
         event = yield self.store.get_event(msg_event.event_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c0e0155bb4..625b651e91 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -128,4 +128,4 @@ class TokenGenerator:
 
     def generate(self, user_id):
         self._last_issued_token += 1
-        return u"%s-%d" % (user_id, self._last_issued_token)
+        return "%s-%d" % (user_id, self._last_issued_token)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index a1ea23b068..1bee45706f 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -78,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def STALE_test_room_name(self):
-        name = u"A-Room-Name"
+        name = "A-Room-Name"
 
         yield self.inject_room_event(
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
@@ -94,7 +94,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def STALE_test_room_topic(self):
-        topic = u"A place for things"
+        topic = "A place for things"
 
         yield self.inject_room_event(
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b6169436de..212a7ae765 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -76,10 +76,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_state_groups_ids(self):
         e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, '', {}
+            self.room, self.u_alice, EventTypes.Create, "", {}
         )
         e2 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+            self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
         state_group_map = yield self.store.get_state_groups_ids(
@@ -89,16 +89,16 @@ class StateStoreTestCase(tests.unittest.TestCase):
         state_map = list(state_group_map.values())[0]
         self.assertDictEqual(
             state_map,
-            {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id},
+            {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
         )
 
     @defer.inlineCallbacks
     def test_get_state_groups(self):
         e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, '', {}
+            self.room, self.u_alice, EventTypes.Create, "", {}
         )
         e2 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+            self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
         state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
@@ -113,10 +113,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, '', {}
+            self.room, self.u_alice, EventTypes.Create, "", {}
         )
         e2 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+            self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
         e3 = yield self.inject_state_event(
             self.room,
@@ -158,7 +158,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can filter to the m.room.name event (with a '' state key)
         state = yield self.store.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
+            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 84ef5e5ba4..7f67ee9e1f 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -24,14 +24,14 @@ from . import unittest
 class PreviewTestCase(unittest.TestCase):
     def test_long_summarize(self):
         example_paras = [
-            u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
+            """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
             Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
             Troms county, Norway. The administrative centre of the municipality is
             the city of Tromsø. Outside of Norway, Tromso and Tromsö are
             alternative spellings of the city.Tromsø is considered the northernmost
             city in the world with a population above 50,000. The most populous town
             north of it is Alta, Norway, with a population of 14,272 (2013).""",
-            u"""Tromsø lies in Northern Norway. The municipality has a population of
+            """Tromsø lies in Northern Norway. The municipality has a population of
             (2015) 72,066, but with an annual influx of students it has over 75,000
             most of the year. It is the largest urban area in Northern Norway and the
             third largest north of the Arctic Circle (following Murmansk and Norilsk).
@@ -44,7 +44,7 @@ class PreviewTestCase(unittest.TestCase):
             Sandnessund Bridge. Tromsø Airport connects the city to many destinations
             in Europe. The city is warmer than most other places located on the same
             latitude, due to the warming effect of the Gulf Stream.""",
-            u"""The city centre of Tromsø contains the highest number of old wooden
+            """The city centre of Tromsø contains the highest number of old wooden
             houses in Northern Norway, the oldest house dating from 1789. The Arctic
             Cathedral, a modern church from 1965, is probably the most famous landmark
             in Tromsø. The city is a cultural centre for its region, with several
@@ -58,87 +58,87 @@ class PreviewTestCase(unittest.TestCase):
 
         self.assertEquals(
             desc,
-            u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
-            u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
-            u" Troms county, Norway. The administrative centre of the municipality is"
-            u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
-            u" alternative spellings of the city.Tromsø is considered the northernmost"
-            u" city in the world with a population above 50,000. The most populous town"
-            u" north of it is Alta, Norway, with a population of 14,272 (2013).",
+            "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+            " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+            " Troms county, Norway. The administrative centre of the municipality is"
+            " the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
+            " alternative spellings of the city.Tromsø is considered the northernmost"
+            " city in the world with a population above 50,000. The most populous town"
+            " north of it is Alta, Norway, with a population of 14,272 (2013).",
         )
 
         desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
 
         self.assertEquals(
             desc,
-            u"Tromsø lies in Northern Norway. The municipality has a population of"
-            u" (2015) 72,066, but with an annual influx of students it has over 75,000"
-            u" most of the year. It is the largest urban area in Northern Norway and the"
-            u" third largest north of the Arctic Circle (following Murmansk and Norilsk)."
-            u" Most of Tromsø, including the city centre, is located on the island of"
-            u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
-            u" Tromsøya had a population of 36,088. Substantial parts of the urban…",
+            "Tromsø lies in Northern Norway. The municipality has a population of"
+            " (2015) 72,066, but with an annual influx of students it has over 75,000"
+            " most of the year. It is the largest urban area in Northern Norway and the"
+            " third largest north of the Arctic Circle (following Murmansk and Norilsk)."
+            " Most of Tromsø, including the city centre, is located on the island of"
+            " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
+            " Tromsøya had a population of 36,088. Substantial parts of the urban…",
         )
 
     def test_short_summarize(self):
         example_paras = [
-            u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
-            u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
-            u" Troms county, Norway.",
-            u"Tromsø lies in Northern Norway. The municipality has a population of"
-            u" (2015) 72,066, but with an annual influx of students it has over 75,000"
-            u" most of the year.",
-            u"The city centre of Tromsø contains the highest number of old wooden"
-            u" houses in Northern Norway, the oldest house dating from 1789. The Arctic"
-            u" Cathedral, a modern church from 1965, is probably the most famous landmark"
-            u" in Tromsø.",
+            "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+            " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+            " Troms county, Norway.",
+            "Tromsø lies in Northern Norway. The municipality has a population of"
+            " (2015) 72,066, but with an annual influx of students it has over 75,000"
+            " most of the year.",
+            "The city centre of Tromsø contains the highest number of old wooden"
+            " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+            " Cathedral, a modern church from 1965, is probably the most famous landmark"
+            " in Tromsø.",
         ]
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
         self.assertEquals(
             desc,
-            u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
-            u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
-            u" Troms county, Norway.\n"
-            u"\n"
-            u"Tromsø lies in Northern Norway. The municipality has a population of"
-            u" (2015) 72,066, but with an annual influx of students it has over 75,000"
-            u" most of the year.",
+            "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+            " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+            " Troms county, Norway.\n"
+            "\n"
+            "Tromsø lies in Northern Norway. The municipality has a population of"
+            " (2015) 72,066, but with an annual influx of students it has over 75,000"
+            " most of the year.",
         )
 
     def test_small_then_large_summarize(self):
         example_paras = [
-            u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
-            u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
-            u" Troms county, Norway.",
-            u"Tromsø lies in Northern Norway. The municipality has a population of"
-            u" (2015) 72,066, but with an annual influx of students it has over 75,000"
-            u" most of the year."
-            u" The city centre of Tromsø contains the highest number of old wooden"
-            u" houses in Northern Norway, the oldest house dating from 1789. The Arctic"
-            u" Cathedral, a modern church from 1965, is probably the most famous landmark"
-            u" in Tromsø.",
+            "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+            " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+            " Troms county, Norway.",
+            "Tromsø lies in Northern Norway. The municipality has a population of"
+            " (2015) 72,066, but with an annual influx of students it has over 75,000"
+            " most of the year."
+            " The city centre of Tromsø contains the highest number of old wooden"
+            " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+            " Cathedral, a modern church from 1965, is probably the most famous landmark"
+            " in Tromsø.",
         ]
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
         self.assertEquals(
             desc,
-            u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
-            u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
-            u" Troms county, Norway.\n"
-            u"\n"
-            u"Tromsø lies in Northern Norway. The municipality has a population of"
-            u" (2015) 72,066, but with an annual influx of students it has over 75,000"
-            u" most of the year. The city centre of Tromsø contains the highest number"
-            u" of old wooden houses in Northern Norway, the oldest house dating from"
-            u" 1789. The Arctic Cathedral, a modern church from…",
+            "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+            " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+            " Troms county, Norway.\n"
+            "\n"
+            "Tromsø lies in Northern Norway. The municipality has a population of"
+            " (2015) 72,066, but with an annual influx of students it has over 75,000"
+            " most of the year. The city centre of Tromsø contains the highest number"
+            " of old wooden houses in Northern Norway, the oldest house dating from"
+            " 1789. The Arctic Cathedral, a modern church from…",
         )
 
 
 class PreviewUrlTestCase(unittest.TestCase):
     def test_simple(self):
-        html = u"""
+        html = """
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -149,10 +149,10 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
-        html = u"""
+        html = """
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -164,10 +164,10 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
-        html = u"""
+        html = """
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -185,13 +185,13 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEquals(
             og,
             {
-                u"og:title": u"Foo",
-                u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text",
+                "og:title": "Foo",
+                "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text",
             },
         )
 
     def test_script(self):
-        html = u"""
+        html = """
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -203,10 +203,10 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
-        html = u"""
+        html = """
         <html>
         <body>
         Some text.
@@ -216,10 +216,10 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
-        html = u"""
+        html = """
         <html>
         <meta property="og:description" content="Some text."/>
         <body>
@@ -230,10 +230,10 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": u"Title", u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
-        html = u"""
+        html = """
         <html>
         <body>
         <h1><a href="foo"/></h1>
@@ -244,4 +244,4 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."})
+        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
diff --git a/tests/test_server.py b/tests/test_server.py
index 08fb3fe02f..da29ae92ce 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -69,8 +69,8 @@ class JsonResourceTests(unittest.TestCase):
         )
         render(request, res, self.reactor)
 
-        self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
-        self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
+        self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
+        self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
 
     def test_callback_direct_exception(self):
         """
@@ -87,7 +87,7 @@ class JsonResourceTests(unittest.TestCase):
         request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
         render(request, res, self.reactor)
 
-        self.assertEqual(channel.result["code"], b'500')
+        self.assertEqual(channel.result["code"], b"500")
 
     def test_callback_indirect_exception(self):
         """
@@ -110,7 +110,7 @@ class JsonResourceTests(unittest.TestCase):
         request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
         render(request, res, self.reactor)
 
-        self.assertEqual(channel.result["code"], b'500')
+        self.assertEqual(channel.result["code"], b"500")
 
     def test_callback_synapseerror(self):
         """
@@ -127,7 +127,7 @@ class JsonResourceTests(unittest.TestCase):
         request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
         render(request, res, self.reactor)
 
-        self.assertEqual(channel.result["code"], b'403')
+        self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
@@ -148,7 +148,7 @@ class JsonResourceTests(unittest.TestCase):
         request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
         render(request, res, self.reactor)
 
-        self.assertEqual(channel.result["code"], b'400')
+        self.assertEqual(channel.result["code"], b"400")
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
@@ -180,7 +180,7 @@ class SiteTestCase(unittest.HomeserverTestCase):
         # Make a resource and a Site, the resource will hang and allow us to
         # time out the request while it's 'processing'
         base_resource = Resource()
-        base_resource.putChild(b'', HangingResource())
+        base_resource.putChild(b"", HangingResource())
         site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
 
         server = site.buildProtocol(None)
diff --git a/tests/test_state.py b/tests/test_state.py
index 6491a7105a..6d33566f47 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -366,11 +366,11 @@ class StateTestCase(unittest.TestCase):
     def _add_depths(self, nodes, edges):
         def _get_depth(ev):
             node = nodes[ev]
-            if 'depth' not in node:
+            if "depth" not in node:
                 prevs = edges[ev]
                 depth = max(_get_depth(prev) for prev in prevs) + 1
-                node['depth'] = depth
-            return node['depth']
+                node["depth"] = depth
+            return node["depth"]
 
         for n in nodes:
             _get_depth(n)
diff --git a/tests/test_types.py b/tests/test_types.py
index d83c36559f..9ab5f829b0 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -102,7 +102,7 @@ class MapUsernameTestCase(unittest.TestCase):
 
     def testNonAscii(self):
         # this should work with either a unicode or a bytes
-        self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
+        self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(
-            map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast"
+            map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
         )
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fde0baee8e..813f984199 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
 
     def emit(self, record):
         log_entry = self.format(record)
-        log_level = record.levelname.lower().replace('warning', 'warn')
+        log_level = record.levelname.lower().replace("warning", "warn")
         self.tx_log.emit(
             twisted.logger.LogLevel.levelWithName(log_level),
             log_entry.replace("{", r"(").replace("}", r")"),
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 6a180ddc32..118c3bd238 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -265,7 +265,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         pr.disable()
         with open("filter_events_for_server.profile", "w+") as f:
-            ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
+            ps = pstats.Stats(pr, stream=f).sort_stats("cumulative")
             ps.print_stats()
 
         # the result should be 5 redacted events, and 5 unredacted events.
diff --git a/tests/unittest.py b/tests/unittest.py
index 26204470b1..36df43c137 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -27,11 +27,12 @@ import twisted.logger
 from twisted.internet.defer import Deferred
 from twisted.trial import unittest
 
+from synapse.api.constants import EventTypes
 from synapse.config.homeserver import HomeServerConfig
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 from synapse.util.logcontext import LoggingContext
 
 from tests.server import get_clock, make_request, render, setup_test_homeserver
@@ -297,7 +298,7 @@ class HomeserverTestCase(TestCase):
             Tuple[synapse.http.site.SynapseRequest, channel]
         """
         if isinstance(content, dict):
-            content = json.dumps(content).encode('utf8')
+            content = json.dumps(content).encode("utf8")
 
         return make_request(
             self.reactor,
@@ -341,7 +342,7 @@ class HomeserverTestCase(TestCase):
 
         # Parse the config from a config dict into a HomeServerConfig
         config_obj = HomeServerConfig()
-        config_obj.parse_config_dict(config)
+        config_obj.parse_config_dict(config, "", "")
         kwargs["config"] = config_obj
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
@@ -388,7 +389,7 @@ class HomeserverTestCase(TestCase):
         Returns:
             The MXID of the new user (unicode).
         """
-        self.hs.config.registration_shared_secret = u"shared"
+        self.hs.config.registration_shared_secret = "shared"
 
         # Create the user
         request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
@@ -396,13 +397,13 @@ class HomeserverTestCase(TestCase):
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
-        nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+        nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
         if admin:
             nonce_str += b"\x00admin"
         else:
             nonce_str += b"\x00notadmin"
 
-        want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+        want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
         want_mac = want_mac.hexdigest()
 
         body = json.dumps(
@@ -415,7 +416,7 @@ class HomeserverTestCase(TestCase):
             }
         )
         request, channel = self.make_request(
-            "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+            "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
         )
         self.render(request)
         self.assertEqual(channel.code, 200)
@@ -434,10 +435,78 @@ class HomeserverTestCase(TestCase):
             body["device_id"] = device_id
 
         request, channel = self.make_request(
-            "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.result)
 
         access_token = channel.json_body["access_token"]
         return access_token
+
+    def create_and_send_event(
+        self, room_id, user, soft_failed=False, prev_event_ids=None
+    ):
+        """
+        Create and send an event.
+
+        Args:
+            soft_failed (bool): Whether to create a soft failed event or not
+            prev_event_ids (list[str]|None): Explicitly set the prev events,
+                or if None just use the default
+
+        Returns:
+            str: The new event's ID.
+        """
+        event_creator = self.hs.get_event_creation_handler()
+        secrets = self.hs.get_secrets()
+        requester = Requester(user, None, False, None, None)
+
+        prev_events_and_hashes = None
+        if prev_event_ids:
+            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
+
+        event, context = self.get_success(
+            event_creator.create_event(
+                requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": room_id,
+                    "sender": user.to_string(),
+                    "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
+                },
+                prev_events_and_hashes=prev_events_and_hashes,
+            )
+        )
+
+        if soft_failed:
+            event.internal_metadata.soft_failed = True
+
+        self.get_success(event_creator.send_nonmember_event(requester, event, context))
+
+        return event.event_id
+
+    def add_extremity(self, room_id, event_id):
+        """
+        Add the given event as an extremity to the room.
+        """
+        self.get_success(
+            self.hs.get_datastore()._simple_insert(
+                table="event_forward_extremities",
+                values={"room_id": room_id, "event_id": event_id},
+                desc="test_add_extremity",
+            )
+        )
+
+        self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,))
+
+    def attempt_wrong_password_login(self, username, password):
+        """Attempts to login as the user with the given password, asserting
+        that the attempt *fails*.
+        """
+        body = {"type": "m.login.password", "user": username, "password": password}
+
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 403, channel.result)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 463a737efa..6f8f52537c 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -88,24 +88,24 @@ class DescriptorTestCase(unittest.TestCase):
 
         obj = Cls()
 
-        obj.mock.return_value = 'fish'
+        obj.mock.return_value = "fish"
         r = yield obj.fn(1, 2)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = 'chips'
+        obj.mock.return_value = "chips"
         r = yield obj.fn(1, 3)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
         # the two values should now be cached
         r = yield obj.fn(1, 2)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         r = yield obj.fn(1, 3)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
     @defer.inlineCallbacks
@@ -121,25 +121,25 @@ class DescriptorTestCase(unittest.TestCase):
                 return self.mock(arg1, arg2)
 
         obj = Cls()
-        obj.mock.return_value = 'fish'
+        obj.mock.return_value = "fish"
         r = yield obj.fn(1, 2)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = 'chips'
+        obj.mock.return_value = "chips"
         r = yield obj.fn(2, 3)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_called_once_with(2, 3)
         obj.mock.reset_mock()
 
         # the two values should now be cached; we should be able to vary
         # the second argument and still get the cached result.
         r = yield obj.fn(1, 4)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         r = yield obj.fn(2, 5)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
     def test_cache_logcontexts(self):
@@ -248,30 +248,30 @@ class DescriptorTestCase(unittest.TestCase):
 
         obj = Cls()
 
-        obj.mock.return_value = 'fish'
+        obj.mock.return_value = "fish"
         r = yield obj.fn(1, 2, 3)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         obj.mock.assert_called_once_with(1, 2, 3)
         obj.mock.reset_mock()
 
         # a call with same params shouldn't call the mock again
         r = yield obj.fn(1, 2)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         obj.mock.assert_not_called()
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = 'chips'
+        obj.mock.return_value = "chips"
         r = yield obj.fn(2, 3)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_called_once_with(2, 3, 3)
         obj.mock.reset_mock()
 
         # the two values should now be cached
         r = yield obj.fn(1, 2)
-        self.assertEqual(r, 'fish')
+        self.assertEqual(r, "fish")
         r = yield obj.fn(2, 3)
-        self.assertEqual(r, 'chips')
+        self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
 
@@ -297,7 +297,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         with logcontext.LoggingContext() as c1:
             c1.request = "c1"
             obj = Cls()
-            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            obj.mock.return_value = {10: "fish", 20: "chips"}
             d1 = obj.list_fn([10, 20], 2)
             self.assertEqual(
                 logcontext.LoggingContext.current_context(),
@@ -306,26 +306,26 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             r = yield d1
             self.assertEqual(logcontext.LoggingContext.current_context(), c1)
             obj.mock.assert_called_once_with([10, 20], 2)
-            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            self.assertEqual(r, {10: "fish", 20: "chips"})
             obj.mock.reset_mock()
 
             # a call with different params should call the mock again
-            obj.mock.return_value = {30: 'peas'}
+            obj.mock.return_value = {30: "peas"}
             r = yield obj.list_fn([20, 30], 2)
             obj.mock.assert_called_once_with([30], 2)
-            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            self.assertEqual(r, {20: "chips", 30: "peas"})
             obj.mock.reset_mock()
 
             # all the values should now be cached
             r = yield obj.fn(10, 2)
-            self.assertEqual(r, 'fish')
+            self.assertEqual(r, "fish")
             r = yield obj.fn(20, 2)
-            self.assertEqual(r, 'chips')
+            self.assertEqual(r, "chips")
             r = yield obj.fn(30, 2)
-            self.assertEqual(r, 'peas')
+            self.assertEqual(r, "peas")
             r = yield obj.list_fn([10, 20, 30], 2)
             obj.mock.assert_not_called()
-            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+            self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
 
     @defer.inlineCallbacks
     def test_invalidate(self):
@@ -350,16 +350,16 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         invalidate1 = mock.Mock()
 
         # cache miss
-        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        obj.mock.return_value = {10: "fish", 20: "chips"}
         r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
         obj.mock.assert_called_once_with([10, 20], 2)
-        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        self.assertEqual(r1, {10: "fish", 20: "chips"})
         obj.mock.reset_mock()
 
         # cache hit
         r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
         obj.mock.assert_not_called()
-        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+        self.assertEqual(r2, {10: "fish", 20: "chips"})
 
         invalidate0.assert_not_called()
         invalidate1.assert_not_called()
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index 03b3c15db6..c94cbb662b 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -27,57 +27,57 @@ class CacheTestCase(unittest.TestCase):
 
     def test_get(self):
         """simple set/get tests"""
-        self.cache.set('one', '1', 10)
-        self.cache.set('two', '2', 20)
-        self.cache.set('three', '3', 30)
+        self.cache.set("one", "1", 10)
+        self.cache.set("two", "2", 20)
+        self.cache.set("three", "3", 30)
 
         self.assertEqual(len(self.cache), 3)
 
-        self.assertTrue('one' in self.cache)
-        self.assertEqual(self.cache.get('one'), '1')
-        self.assertEqual(self.cache['one'], '1')
-        self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
+        self.assertTrue("one" in self.cache)
+        self.assertEqual(self.cache.get("one"), "1")
+        self.assertEqual(self.cache["one"], "1")
+        self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110))
         self.assertEqual(self.cache._metrics.hits, 3)
         self.assertEqual(self.cache._metrics.misses, 0)
 
-        self.cache.set('two', '2.5', 20)
-        self.assertEqual(self.cache['two'], '2.5')
+        self.cache.set("two", "2.5", 20)
+        self.assertEqual(self.cache["two"], "2.5")
         self.assertEqual(self.cache._metrics.hits, 4)
 
         # non-existent-item tests
-        self.assertEqual(self.cache.get('four', '4'), '4')
-        self.assertIs(self.cache.get('four', None), None)
+        self.assertEqual(self.cache.get("four", "4"), "4")
+        self.assertIs(self.cache.get("four", None), None)
 
         with self.assertRaises(KeyError):
-            self.cache['four']
+            self.cache["four"]
 
         with self.assertRaises(KeyError):
-            self.cache.get('four')
+            self.cache.get("four")
 
         with self.assertRaises(KeyError):
-            self.cache.get_with_expiry('four')
+            self.cache.get_with_expiry("four")
 
         self.assertEqual(self.cache._metrics.hits, 4)
         self.assertEqual(self.cache._metrics.misses, 5)
 
     def test_expiry(self):
-        self.cache.set('one', '1', 10)
-        self.cache.set('two', '2', 20)
-        self.cache.set('three', '3', 30)
+        self.cache.set("one", "1", 10)
+        self.cache.set("two", "2", 20)
+        self.cache.set("three", "3", 30)
 
         self.assertEqual(len(self.cache), 3)
-        self.assertEqual(self.cache['one'], '1')
-        self.assertEqual(self.cache['two'], '2')
+        self.assertEqual(self.cache["one"], "1")
+        self.assertEqual(self.cache["two"], "2")
 
         # enough for the first entry to expire, but not the rest
         self.mock_timer.side_effect = lambda: 110.0
 
         self.assertEqual(len(self.cache), 2)
-        self.assertFalse('one' in self.cache)
-        self.assertEqual(self.cache['two'], '2')
-        self.assertEqual(self.cache['three'], '3')
+        self.assertFalse("one" in self.cache)
+        self.assertEqual(self.cache["two"], "2")
+        self.assertEqual(self.cache["three"], "3")
 
-        self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
+        self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120))
 
         self.assertEqual(self.cache._metrics.hits, 5)
         self.assertEqual(self.cache._metrics.misses, 0)
diff --git a/tests/utils.py b/tests/utils.py
index 200c1ceabe..da43166f3a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -31,6 +31,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
@@ -131,7 +132,6 @@ def default_config(name, parse=False):
         "password_providers": [],
         "worker_replication_url": "",
         "worker_app": None,
-        "email_enable_notifs": False,
         "block_non_admin_invites": False,
         "federation_domain_whitelist": None,
         "filter_timeline_limit": 5000,
@@ -174,7 +174,7 @@ def default_config(name, parse=False):
         "use_frozen_dicts": False,
         # We need a sane default_room_version, otherwise attempts to create
         # rooms will fail.
-        "default_room_version": "1",
+        "default_room_version": DEFAULT_ROOM_VERSION,
         # disable user directory updates, because they get done in the
         # background, which upsets the test runner.
         "update_user_directory": False,
@@ -182,7 +182,7 @@ def default_config(name, parse=False):
 
     if parse:
         config = HomeServerConfig()
-        config.parse_config_dict(config_dict)
+        config.parse_config_dict(config_dict, "", "")
         return config
 
     return config_dict
@@ -358,9 +358,9 @@ def setup_test_homeserver(
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
-    hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
+    hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
     hs.get_auth_handler().validate_hash = (
-        lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
+        lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
     )
 
     fed = kargs.get("resource_for_federation", None)
@@ -407,7 +407,7 @@ class MockHttpResource(HttpServer):
     def trigger_get(self, path):
         return self.trigger(b"GET", path, None)
 
-    @patch('twisted.web.http.Request')
+    @patch("twisted.web.http.Request")
     @defer.inlineCallbacks
     def trigger(
         self, http_method, path, content, mock_request, federation_auth_origin=None
@@ -431,12 +431,12 @@ class MockHttpResource(HttpServer):
         # annoyingly we return a twisted http request which has chained calls
         # to get at the http content, hence mock it here.
         mock_content = Mock()
-        config = {'read.return_value': content}
+        config = {"read.return_value": content}
         mock_content.configure_mock(**config)
         mock_request.content = mock_content
 
-        mock_request.method = http_method.encode('ascii')
-        mock_request.uri = path.encode('ascii')
+        mock_request.method = http_method.encode("ascii")
+        mock_request.uri = path.encode("ascii")
 
         mock_request.getClientIP.return_value = "-"
 
@@ -452,14 +452,14 @@ class MockHttpResource(HttpServer):
 
         # add in query params to the right place
         try:
-            mock_request.args = urlparse.parse_qs(path.split('?')[1])
-            mock_request.path = path.split('?')[0]
+            mock_request.args = urlparse.parse_qs(path.split("?")[1])
+            mock_request.path = path.split("?")[0]
             path = mock_request.path
         except Exception:
             pass
 
         if isinstance(path, bytes):
-            path = path.decode('utf8')
+            path = path.decode("utf8")
 
         for (method, pattern, func) in self.callbacks:
             if http_method != method: