summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py227
-rw-r--r--tests/api/test_filtering.py38
-rw-r--r--tests/api/test_ratelimiting.py22
-rw-r--r--tests/config/test_cache.py6
-rw-r--r--tests/config/test_server.py61
-rw-r--r--tests/crypto/test_keyring.py9
-rw-r--r--tests/events/test_snapshot.py5
-rw-r--r--tests/federation/test_complexity.py4
-rw-r--r--tests/federation/test_federation_sender.py15
-rw-r--r--tests/handlers/test_admin.py9
-rw-r--r--tests/handlers/test_appservice.py33
-rw-r--r--tests/handlers/test_auth.py133
-rw-r--r--tests/handlers/test_cas.py60
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py20
-rw-r--r--tests/handlers/test_e2e_keys.py238
-rw-r--r--tests/handlers/test_e2e_room_keys.py347
-rw-r--r--tests/handlers/test_federation.py54
-rw-r--r--tests/handlers/test_message.py13
-rw-r--r--tests/handlers/test_oidc.py158
-rw-r--r--tests/handlers/test_password_providers.py11
-rw-r--r--tests/handlers/test_presence.py3
-rw-r--r--tests/handlers/test_profile.py125
-rw-r--r--tests/handlers/test_saml.py64
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/handlers/test_user_directory.py12
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py6
-rw-r--r--tests/http/test_client.py9
-rw-r--r--tests/module_api/test_api.py5
-rw-r--r--tests/push/test_email.py81
-rw-r--r--tests/push/test_presentable_names.py229
-rw-r--r--tests/push/test_push_rule_evaluator.py2
-rw-r--r--tests/replication/_base.py110
-rw-r--r--tests/replication/slave/storage/test_events.py5
-rw-r--r--tests/replication/tcp/streams/test_account_data.py6
-rw-r--r--tests/replication/tcp/streams/test_events.py13
-rw-r--r--tests/replication/tcp/test_remote_server_up.py3
-rw-r--r--tests/replication/test_auth.py12
-rw-r--r--tests/replication/test_client_reader_shard.py6
-rw-r--r--tests/replication/test_multi_media_repo.py12
-rw-r--r--tests/replication/test_pusher_shard.py21
-rw-r--r--tests/replication/test_sharded_event_persister.py25
-rw-r--r--tests/rest/admin/test_admin.py34
-rw-r--r--tests/rest/admin/test_device.py148
-rw-r--r--tests/rest/admin/test_event_reports.py99
-rw-r--r--tests/rest/admin/test_media.py60
-rw-r--r--tests/rest/admin/test_room.py275
-rw-r--r--tests/rest/admin/test_statistics.py133
-rw-r--r--tests/rest/admin/test_user.py613
-rw-r--r--tests/rest/client/test_power_levels.py4
-rw-r--r--tests/rest/client/test_redactions.py3
-rw-r--r--tests/rest/client/test_retention.py3
-rw-r--r--tests/rest/client/test_shadow_banned.py18
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_login.py160
-rw-r--r--tests/rest/client/v1/test_profile.py249
-rw-r--r--tests/rest/client/v1/test_rooms.py51
-rw-r--r--tests/rest/client/v1/test_typing.py32
-rw-r--r--tests/rest/client/v1/utils.py36
-rw-r--r--tests/rest/client/v2_alpha/test_account.py154
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py32
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py20
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py27
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py23
-rw-r--r--tests/rest/client/v2_alpha/test_upgrade_room.py161
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py7
-rw-r--r--tests/rest/media/v1/test_media_storage.py129
-rw-r--r--tests/rest/test_well_known.py9
-rw-r--r--tests/server.py12
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py6
-rw-r--r--tests/state/test_v2.py75
-rw-r--r--tests/storage/test_account_data.py4
-rw-r--r--tests/storage/test_background_update.py4
-rw-r--r--tests/storage/test_cleanup_extrems.py3
-rw-r--r--tests/storage/test_client_ips.py24
-rw-r--r--tests/storage/test_event_chain.py23
-rw-r--r--tests/storage/test_event_federation.py17
-rw-r--r--tests/storage/test_event_push_actions.py4
-rw-r--r--tests/storage/test_events.py6
-rw-r--r--tests/storage/test_id_generators.py30
-rw-r--r--tests/storage/test_monthly_active_users.py2
-rw-r--r--tests/storage/test_redaction.py9
-rw-r--r--tests/storage/test_registration.py11
-rw-r--r--tests/test_event_auth.py25
-rw-r--r--tests/test_mau.py5
-rw-r--r--tests/test_metrics.py2
-rw-r--r--tests/test_preview.py130
-rw-r--r--tests/test_server.py5
-rw-r--r--tests/unittest.py10
-rw-r--r--tests/util/caches/test_cached_call.py161
-rw-r--r--tests/util/caches/test_deferred_cache.py5
-rw-r--r--tests/util/caches/test_descriptors.py12
-rw-r--r--tests/util/test_itertools.py12
-rw-r--r--tests/util/test_stream_change_cache.py9
-rw-r--r--tests/utils.py15
95 files changed, 3710 insertions, 1623 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ee5217b074..34f72ae795 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -17,8 +17,6 @@ from mock import Mock
 
 import pymacaroons
 
-from twisted.internet import defer
-
 from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
@@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import UserID
 
 from tests import unittest
-from tests.utils import mock_getRawHeaders, setup_test_homeserver
+from tests.test_utils import simple_async_mock
+from tests.utils import mock_getRawHeaders
 
 
-class AuthTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.state_handler = Mock()
+class AuthTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = Mock()
 
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.hs.get_datastore = Mock(return_value=self.store)
-        self.hs.get_auth_handler().store = self.store
-        self.auth = Auth(self.hs)
+        hs.get_datastore = Mock(return_value=self.store)
+        hs.get_auth_handler().store = self.store
+        self.auth = Auth(hs)
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
@@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
-        self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
-        self.store.is_support_user = Mock(return_value=defer.succeed(False))
+        self.store.insert_client_ip = simple_async_mock(None)
+        self.store.is_support_user = simple_async_mock(False)
 
-    @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
         user_info = TokenLookupResult(
             user_id=self.test_user, token_id=5, device_id="device"
         )
-        self.store.get_user_by_access_token = Mock(
-            return_value=defer.succeed(user_info)
-        )
+        self.store.get_user_by_access_token = simple_async_mock(user_info)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+        requester = self.get_success(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        f = self.failureResultOf(d, InvalidClientTokenError).value
+        f = self.get_failure(
+            self.auth.get_user_by_req(request), InvalidClientTokenError
+        ).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
         user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
-        self.store.get_user_by_access_token = Mock(
-            return_value=defer.succeed(user_info)
-        )
+        self.store.get_user_by_access_token = simple_async_mock(user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        f = self.failureResultOf(d, MissingClientTokenError).value
+        f = self.get_failure(
+            self.auth.get_user_by_req(request), MissingClientTokenError
+        ).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
 
-    @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token(self):
         app_service = Mock(
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+        requester = self.get_success(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
-    @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_good_ip(self):
         from netaddr import IPSet
 
@@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+        requester = self.get_success(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        f = self.failureResultOf(d, InvalidClientTokenError).value
+        f = self.get_failure(
+            self.auth.get_user_by_req(request), InvalidClientTokenError
+        ).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        f = self.failureResultOf(d, InvalidClientTokenError).value
+        f = self.get_failure(
+            self.auth.get_user_by_req(request), InvalidClientTokenError
+        ).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        f = self.failureResultOf(d, MissingClientTokenError).value
+        f = self.get_failure(
+            self.auth.get_user_by_req(request), MissingClientTokenError
+        ).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
 
-    @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
         masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
@@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = Mock(
-            return_value=defer.succeed({"is_guest": False})
-        )
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
+        requester = self.get_success(self.auth.get_user_by_req(request))
         self.assertEquals(
             requester.user.to_string(), masquerading_user_id.decode("utf8")
         )
@@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
-        self.failureResultOf(d, AuthError)
+        self.get_failure(self.auth.get_user_by_req(request), AuthError)
 
-    @defer.inlineCallbacks
     def test_get_user_from_macaroon(self):
-        self.store.get_user_by_access_token = Mock(
-            return_value=defer.succeed(
-                TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
-            )
+        self.store.get_user_by_access_token = simple_async_mock(
+            TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
         )
 
         user_id = "@baldrick:matrix.org"
@@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        user_info = yield defer.ensureDeferred(
+        user_info = self.get_success(
             self.auth.get_user_by_access_token(macaroon.serialize())
         )
         self.assertEqual(user_id, user_info.user_id)
@@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase):
         # from the db.
         self.assertEqual(user_info.device_id, "device")
 
-    @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
-        self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
-        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
+        self.store.get_user_by_id = simple_async_mock({"is_guest": True})
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("guest = true")
         serialized = macaroon.serialize()
 
-        user_info = yield defer.ensureDeferred(
-            self.auth.get_user_by_access_token(serialized)
-        )
+        user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
         self.assertEqual(user_id, user_info.user_id)
         self.assertTrue(user_info.is_guest)
         self.store.get_user_by_id.assert_called_with(user_id)
 
-    @defer.inlineCallbacks
     def test_cannot_use_regular_token_as_guest(self):
         USER_ID = "@percy:matrix.org"
-        self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
-        self.store.get_device = Mock(return_value=defer.succeed(None))
+        self.store.add_access_token_to_user = simple_async_mock(None)
+        self.store.get_device = simple_async_mock(None)
 
-        token = yield defer.ensureDeferred(
+        token = self.get_success(
             self.hs.get_auth_handler().get_access_token_for_user_id(
                 USER_ID, "DEVICE", valid_until_ms=None
             )
@@ -289,25 +272,24 @@ class AuthTestCase(unittest.TestCase):
             puppets_user_id=None,
         )
 
-        def get_user(tok):
+        async def get_user(tok):
             if token != tok:
-                return defer.succeed(None)
-            return defer.succeed(
-                TokenLookupResult(
-                    user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
-                )
+                return None
+            return TokenLookupResult(
+                user_id=USER_ID,
+                is_guest=False,
+                token_id=1234,
+                device_id="DEVICE",
             )
 
         self.store.get_user_by_access_token = get_user
-        self.store.get_user_by_id = Mock(
-            return_value=defer.succeed({"is_guest": False})
-        )
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
 
         # check the token works
         request = Mock(args={})
         request.args[b"access_token"] = [token.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield defer.ensureDeferred(
+        requester = self.get_success(
             self.auth.get_user_by_req(request, allow_guest=True)
         )
         self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -323,17 +305,16 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [guest_tok.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
-        with self.assertRaises(InvalidClientCredentialsError) as cm:
-            yield defer.ensureDeferred(
-                self.auth.get_user_by_req(request, allow_guest=True)
-            )
+        cm = self.get_failure(
+            self.auth.get_user_by_req(request, allow_guest=True),
+            InvalidClientCredentialsError,
+        )
 
-        self.assertEqual(401, cm.exception.code)
-        self.assertEqual("Guest access token used for regular user", cm.exception.msg)
+        self.assertEqual(401, cm.value.code)
+        self.assertEqual("Guest access token used for regular user", cm.value.msg)
 
         self.store.get_user_by_id.assert_called_with(USER_ID)
 
-    @defer.inlineCallbacks
     def test_blocking_mau(self):
         self.auth_blocking._limit_usage_by_mau = False
         self.auth_blocking._max_mau_value = 50
@@ -341,77 +322,61 @@ class AuthTestCase(unittest.TestCase):
         small_number_of_users = 1
 
         # Ensure no error thrown
-        yield defer.ensureDeferred(self.auth.check_auth_blocking())
+        self.get_success(self.auth.check_auth_blocking())
 
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(lots_of_users)
-        )
+        self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
 
-        with self.assertRaises(ResourceLimitError) as e:
-            yield defer.ensureDeferred(self.auth.check_auth_blocking())
-        self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.exception.code, 403)
+        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEquals(e.value.code, 403)
 
         # Ensure does not throw an error
-        self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(small_number_of_users)
-        )
-        yield defer.ensureDeferred(self.auth.check_auth_blocking())
+        self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+        self.get_success(self.auth.check_auth_blocking())
 
-    @defer.inlineCallbacks
     def test_blocking_mau__depending_on_user_type(self):
         self.auth_blocking._max_mau_value = 50
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+        self.store.get_monthly_active_count = simple_async_mock(100)
         # Support users allowed
-        yield defer.ensureDeferred(
-            self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
-        )
-        self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+        self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+        self.store.get_monthly_active_count = simple_async_mock(100)
         # Bots not allowed
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth.check_auth_blocking(user_type=UserTypes.BOT)
-            )
-        self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+        self.get_failure(
+            self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+        )
+        self.store.get_monthly_active_count = simple_async_mock(100)
         # Real users not allowed
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(self.auth.check_auth_blocking())
+        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
 
-    @defer.inlineCallbacks
     def test_reserved_threepid(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._max_mau_value = 1
-        self.store.get_monthly_active_count = lambda: defer.succeed(2)
+        self.store.get_monthly_active_count = simple_async_mock(2)
         threepid = {"medium": "email", "address": "reserved@server.com"}
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(self.auth.check_auth_blocking())
+        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth.check_auth_blocking(threepid=unknown_threepid)
-            )
+        self.get_failure(
+            self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+        )
 
-        yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
+        self.get_success(self.auth.check_auth_blocking(threepid=threepid))
 
-    @defer.inlineCallbacks
     def test_hs_disabled(self):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        with self.assertRaises(ResourceLimitError) as e:
-            yield defer.ensureDeferred(self.auth.check_auth_blocking())
-        self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.exception.code, 403)
+        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEquals(e.value.code, 403)
 
-    @defer.inlineCallbacks
     def test_hs_disabled_no_server_notices_user(self):
         """Check that 'hs_disabled_message' works correctly when there is no
         server_notices user.
@@ -422,16 +387,14 @@ class AuthTestCase(unittest.TestCase):
 
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        with self.assertRaises(ResourceLimitError) as e:
-            yield defer.ensureDeferred(self.auth.check_auth_blocking())
-        self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.exception.code, 403)
+        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact)
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEquals(e.value.code, 403)
 
-    @defer.inlineCallbacks
     def test_server_notices_mxid_special_cased(self):
         self.auth_blocking._hs_disabled = True
         user = "@user:server"
         self.auth_blocking._server_notices_mxid = user
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
+        self.get_success(self.auth.check_auth_blocking(user))
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 279c94a03d..ab7d290724 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -18,15 +18,12 @@
 
 import jsonschema
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import make_event_from_dict
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 user_localpart = "test_user"
 
@@ -39,9 +36,8 @@ def MockEvent(**kwargs):
     return make_event_from_dict(kwargs)
 
 
-class FilteringTestCase(unittest.TestCase):
-    def setUp(self):
-        hs = setup_test_homeserver(self.addCleanup)
+class FilteringTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.filtering = hs.get_filtering()
         self.datastore = hs.get_datastore()
 
@@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase):
 
         self.assertTrue(Filter(definition).check(event))
 
-    @defer.inlineCallbacks
     def test_filter_presence_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.datastore.add_user_filter(
                 user_localpart=user_localpart, user_filter=user_filter_json
             )
@@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.profile")
         events = [event]
 
-        user_filter = yield defer.ensureDeferred(
+        user_filter = self.get_success(
             self.filtering.get_user_filter(
                 user_localpart=user_localpart, filter_id=filter_id
             )
@@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase):
         results = user_filter.filter_presence(events=events)
         self.assertEquals(events, results)
 
-    @defer.inlineCallbacks
     def test_filter_presence_no_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
 
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.datastore.add_user_filter(
                 user_localpart=user_localpart + "2", user_filter=user_filter_json
             )
@@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield defer.ensureDeferred(
+        user_filter = self.get_success(
             self.filtering.get_user_filter(
                 user_localpart=user_localpart + "2", filter_id=filter_id
             )
@@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase):
         results = user_filter.filter_presence(events=events)
         self.assertEquals([], results)
 
-    @defer.inlineCallbacks
     def test_filter_room_state_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.datastore.add_user_filter(
                 user_localpart=user_localpart, user_filter=user_filter_json
             )
@@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
         events = [event]
 
-        user_filter = yield defer.ensureDeferred(
+        user_filter = self.get_success(
             self.filtering.get_user_filter(
                 user_localpart=user_localpart, filter_id=filter_id
             )
@@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase):
         results = user_filter.filter_room_state(events=events)
         self.assertEquals(events, results)
 
-    @defer.inlineCallbacks
     def test_filter_room_state_no_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.datastore.add_user_filter(
                 user_localpart=user_localpart, user_filter=user_filter_json
             )
@@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield defer.ensureDeferred(
+        user_filter = self.get_success(
             self.filtering.get_user_filter(
                 user_localpart=user_localpart, filter_id=filter_id
             )
@@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase):
 
         self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
 
-    @defer.inlineCallbacks
     def test_add_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.filtering.add_user_filter(
                 user_localpart=user_localpart, user_filter=user_filter_json
             )
@@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase):
         self.assertEquals(
             user_filter_json,
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.datastore.get_user_filter(
                         user_localpart=user_localpart, filter_id=0
                     )
@@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase):
             ),
         )
 
-    @defer.inlineCallbacks
     def test_get_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield defer.ensureDeferred(
+        filter_id = self.get_success(
             self.datastore.add_user_filter(
                 user_localpart=user_localpart, user_filter=user_filter_json
             )
         )
 
-        filter = yield defer.ensureDeferred(
+        filter = self.get_success(
             self.filtering.get_user_filter(
                 user_localpart=user_localpart, filter_id=filter_id
             )
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fe504d0869..483418192c 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,11 @@ class TestRatelimiter(unittest.TestCase):
 
     def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
         appservice = ApplicationService(
-            None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
+            None,
+            "example.com",
+            id="foo",
+            rate_limited=True,
+            sender="@as:example.com",
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
@@ -68,7 +72,11 @@ class TestRatelimiter(unittest.TestCase):
 
     def test_allowed_appservice_via_can_requester_do_action(self):
         appservice = ApplicationService(
-            None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
+            None,
+            "example.com",
+            id="foo",
+            rate_limited=False,
+            sender="@as:example.com",
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
@@ -113,12 +121,18 @@ class TestRatelimiter(unittest.TestCase):
         limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
 
         # First attempt should be allowed
-        allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+        allowed, time_allowed = limiter.can_do_action(
+            ("test_id",),
+            _time_now_s=0,
+        )
         self.assertTrue(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # Second attempt, 1s later, will fail
-        allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+        allowed, time_allowed = limiter.can_do_action(
+            ("test_id",),
+            _time_now_s=1,
+        )
         self.assertFalse(allowed)
         self.assertEqual(10.0, time_allowed)
 
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index d3ec24c975..2b7f09c14b 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -127,8 +127,7 @@ class CacheConfigTests(TestCase):
         self.assertEqual(cache.max_size, 150)
 
     def test_cache_with_asterisk_in_name(self):
-        """Some caches have asterisks in their name, test that they are set correctly.
-        """
+        """Some caches have asterisks in their name, test that they are set correctly."""
 
         config = {
             "caches": {
@@ -164,7 +163,8 @@ class CacheConfigTests(TestCase):
         t.read_config(config, config_dir_path="", data_dir_path="")
 
         cache = LruCache(
-            max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+            max_size=t.caches.event_cache_size,
+            apply_cache_factor_from_config=False,
         )
         add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
 
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index a10d017120..98af7aa675 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -15,7 +15,8 @@
 
 import yaml
 
-from synapse.config.server import ServerConfig, is_threepid_reserved
+from synapse.config._base import ConfigError
+from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
 
 from tests import unittest
 
@@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase):
         )
 
         self.assertEqual(conf["listeners"], expected_listeners)
+
+
+class GenerateIpSetTestCase(unittest.TestCase):
+    def test_empty(self):
+        ip_set = generate_ip_set(())
+        self.assertFalse(ip_set)
+
+        ip_set = generate_ip_set((), ())
+        self.assertFalse(ip_set)
+
+    def test_generate(self):
+        """Check adding IPv4 and IPv6 addresses."""
+        # IPv4 address
+        ip_set = generate_ip_set(("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv4 CIDR
+        ip_set = generate_ip_set(("1.2.3.4/24",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv6 address
+        ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # IPv6 CIDR
+        ip_set = generate_ip_set(("2001:db8::/104",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # The addresses can overlap OK.
+        ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_extra(self):
+        """Extra IP addresses are treated the same."""
+        ip_set = generate_ip_set((), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 8)
+
+        # They can duplicate without error.
+        ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_bad_value(self):
+        """An error should be raised if a bad value is passed in."""
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("not-an-ip",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("1.2.3.4/128",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set((":::",))
+
+        # The following get treated as empty data.
+        self.assertFalse(generate_ip_set(None))
+        self.assertFalse(generate_ip_set({}))
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 1d65ea2f9c..30fcc4c1bf 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -400,7 +400,10 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         )
 
     def build_perspectives_response(
-        self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+        self,
+        server_name: str,
+        signing_key: SigningKey,
+        valid_until_ts: int,
     ) -> dict:
         """
         Build a valid perspectives server response to a request for the given key
@@ -455,7 +458,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         VALID_UNTIL_TS = 200 * 1000
 
         response = self.build_perspectives_response(
-            SERVER_NAME, testkey, VALID_UNTIL_TS,
+            SERVER_NAME,
+            testkey,
+            VALID_UNTIL_TS,
         )
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 3a80626224..ec85324c0c 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -43,7 +43,10 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         event, context = self.get_success(
             create_event(
-                self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+                self.hs,
+                room_id=self.room_id,
+                type="m.test",
+                sender=self.user_id,
             )
         )
 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9ccd2d76b8..8186b8ca01 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -150,8 +150,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Artificially raise the complexity
-        self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
-            600
+        self.hs.get_datastore().get_current_state_event_counts = (
+            lambda x: make_awaitable(600)
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 917762e6b6..ecc3faa572 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -279,7 +279,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         ret = self.get_success(
             e2e_handler.upload_signatures_for_device_keys(
-                u1, {u1: {"D1": d1_json, "D2": d2_json}},
+                u1,
+                {u1: {"D1": d1_json, "D2": d2_json}},
             )
         )
         self.assertEqual(ret["failures"], {})
@@ -486,9 +487,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.assertGreaterEqual(content["stream_id"], prev_stream_id)
         return content["stream_id"]
 
-    def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
-        """Check that the txn has an EDU with a signing key update.
-        """
+    def check_signing_key_update_txn(
+        self,
+        txn: JsonDict,
+    ) -> None:
+        """Check that the txn has an EDU with a signing key update."""
         edus = txn["edus"]
         self.assertEqual(len(edus), 1)
 
@@ -502,7 +505,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         self.get_success(
             self.hs.get_e2e_keys_handler().upload_keys_for_user(
-                user_id, device_id, {"device_keys": device_dict},
+                user_id,
+                device_id,
+                {"device_keys": device_dict},
             )
         )
         return sk
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 5c2b4de1a6..a01fdd0839 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -44,8 +44,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.token2 = self.login("user2", "password")
 
     def test_single_public_joined_room(self):
-        """Test that we write *all* events for a public room
-        """
+        """Test that we write *all* events for a public room"""
         room_id = self.helper.create_room_as(
             self.user1, tok=self.token1, is_public=True
         )
@@ -116,8 +115,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
     def test_single_left_room(self):
-        """Tests that we don't see events in the room after we leave.
-        """
+        """Tests that we don't see events in the room after we leave."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
         self.helper.join(room_id, self.user2, tok=self.token2)
@@ -190,8 +188,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
 
     def test_invite(self):
-        """Tests that pending invites get handled correctly.
-        """
+        """Tests that pending invites get handled correctly."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
         self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 53763cd0f9..d5d3fdd99a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastore.return_value = self.mock_store
-        self.mock_store.get_received_ts.return_value = defer.succeed(0)
-        self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+        self.mock_store.get_received_ts.return_value = make_awaitable(0)
+        self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
         hs.get_application_service_api.return_value = self.mock_as_api
         hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
@@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested=False),
         ]
 
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed([])
+        self.mock_store.get_user_by_id.return_value = make_awaitable([])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
@@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed(None)
+        self.mock_store.get_user_by_id.return_value = make_awaitable(None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
+        self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             "query_user called when it shouldn't have been.",
         )
 
-    @defer.inlineCallbacks
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"
         room_alias = Mock()
@@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             Mock(room_id=room_id, servers=servers)
         )
 
-        result = yield defer.ensureDeferred(
-            self.handler.query_room_alias_exists(room_alias)
+        result = self.successResultOf(
+            defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
         )
 
         self.mock_as_api.query_alias.assert_called_once_with(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index e24ce81284..0e42013bb9 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -16,28 +16,21 @@ from mock import Mock
 
 import pymacaroons
 
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class AuthTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.auth_handler = self.hs.get_auth_handler()
-        self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.auth_handler = hs.get_auth_handler()
+        self.macaroon_generator = hs.get_macaroon_generator()
 
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth()._auth_blocking
         self.auth_blocking._max_mau_value = 50
 
         self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
             self.fail("some_user was not in %s" % macaroon.inspect())
 
     def test_macaroon_caveats(self):
-        self.hs.get_clock().now = 5000
-
         token = self.macaroon_generator.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
@@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
-        self.hs.get_clock().now = 1000
-
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
         self.assertEqual("a_user", user_id)
 
         # when we advance the clock, the token should be rejected
-        self.hs.get_clock().now = 6000
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
-            )
+        self.reactor.advance(6)
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
             )
@@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    macaroon.serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            ),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_disabled(self):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
+        )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_parity(self):
+        # Ensure we're not at the unix epoch.
+        self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
 
-        # If not in monthly active cohort
+        # Set the server to be at the edge of too many users.
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        # If not in monthly active cohort
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
+
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.clock.time_msec())
         )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
-        self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
-        )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
 
@@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
@@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index c37bb6440e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
 from synapse.handlers.cas_handler import CasResponse
 
 from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
             "server_url": SERVER_URL,
             "service_url": BASE_URL,
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        cas_config.update(config.get("cas_config", {}))
         config["cas_config"] = cas_config
 
         return config
@@ -62,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -85,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -94,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -112,10 +116,54 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+        )
+
+    @override_config(
+        {
+            "cas_config": {
+                "required_attributes": {"userGroup": "staff", "department": None}
+            }
+        }
+    )
+    def test_required_attributes(self):
+        """The required attributes must be met from the CAS response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have any department.
+        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        cas_response = CasResponse(
+            "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 5dfeccfeb6..821629bc38 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -260,7 +260,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         # Create a new login for the user and dehydrated the device
         device_id, access_token = self.get_success(
             self.registration.register_device(
-                user_id=user_id, device_id=None, initial_display_name="new device",
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
             )
         )
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index a39f898608..863d8737b2 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -131,7 +131,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         """A user can create an alias for a room they're in."""
         self.get_success(
             self.handler.create_association(
-                create_requester(self.test_user), self.room_alias, self.room_id,
+                create_requester(self.test_user),
+                self.room_alias,
+                self.room_id,
             )
         )
 
@@ -143,7 +145,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
 
         self.get_failure(
             self.handler.create_association(
-                create_requester(self.test_user), self.room_alias, other_room_id,
+                create_requester(self.test_user),
+                self.room_alias,
+                other_room_id,
             ),
             synapse.api.errors.SynapseError,
         )
@@ -156,7 +160,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.create_association(
-                create_requester(self.admin_user), self.room_alias, other_room_id,
+                create_requester(self.admin_user),
+                self.room_alias,
+                other_room_id,
             )
         )
 
@@ -275,8 +281,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
 
 
 class CanonicalAliasTestCase(unittest.HomeserverTestCase):
-    """Test modifications of the canonical alias when delete aliases.
-    """
+    """Test modifications of the canonical alias when delete aliases."""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -317,7 +322,10 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
     def _set_canonical_alias(self, content):
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
-            self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+            self.room_id,
+            "m.room.canonical_alias",
+            content,
+            tok=self.admin_user_tok,
         )
 
     def _get_canonical_alias(self):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 924f29f051..5e86c5e56b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -18,42 +18,26 @@ import mock
 
 from signedjson import key as key, sign as sign
 
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
 from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
 
-from tests import unittest, utils
+from tests import unittest
 
 
-class E2eKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eKeysHandler
-        self.store = None  # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, federation_client=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_query_local_devices_no_devices(self):
-        """If the user has no devices, we expect an empty list.
-        """
+        """If the user has no devices, we expect an empty list."""
         local_user = "@boris:" + self.hs.hostname
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
-        )
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_reupload_one_time_keys(self):
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
@@ -64,7 +48,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
@@ -73,14 +57,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-    @defer.inlineCallbacks
     def test_change_one_time_keys(self):
         """attempts to change one-time-keys should be rejected"""
 
@@ -92,75 +75,66 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
-                )
-            )
-            self.fail("No error when changing string key")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
-                )
-            )
-            self.fail("No error when replacing dict key with string")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
-                )
-            )
-            self.fail("No error when replacing string key with dict")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {
-                        "one_time_keys": {
-                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                        }
-                    },
-                )
-            )
-            self.fail("No error when replacing dict key")
-        except errors.SynapseError:
-            pass
+        # Error when changing string key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key with strin
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing string key with dict
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {
+                    "one_time_keys": {
+                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                    }
+                },
+            ),
+            SynapseError,
+        )
 
-    @defer.inlineCallbacks
     def test_claim_one_time_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
-        res2 = yield defer.ensureDeferred(
+        res2 = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -173,7 +147,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
@@ -181,12 +154,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
@@ -195,14 +168,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we should now have an unused alg1 key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, ["alg1"])
 
         # claiming an OTK when no OTKs are available should return the fallback
         # key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -213,13 +186,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we shouldn't have any unused fallback keys again
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
         # claiming an OTK again should return the same fallback key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -231,22 +204,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": otk}
             )
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -256,7 +230,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
-    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -270,9 +243,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         keys2 = {
             "master_key": {
@@ -284,16 +255,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys2)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
 
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
@@ -326,9 +294,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
@@ -358,12 +324,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "abc", {"device_keys": device_key_1}
             )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "def", {"device_keys": device_key_2}
             )
@@ -372,7 +338,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # sign the first device key and upload it
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1}}
             )
@@ -383,7 +349,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
             )
@@ -391,7 +357,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +365,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
-    @defer.inlineCallbacks
     def test_self_signing_key_doesnt_show_up_as_device(self):
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
@@ -413,29 +378,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
-
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.hs.get_device_handler().check_device_registered(
-                    user_id=local_user,
-                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                    initial_device_display_name="new display name",
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
-        self.assertEqual(res, 400)
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
+        e = self.get_failure(
+            self.hs.get_device_handler().check_device_registered(
+                user_id=local_user,
+                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                initial_device_display_name="new display name",
+            ),
+            SynapseError,
         )
+        res = e.value.code
+        self.assertEqual(res, 400)
+
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_upload_signatures(self):
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
@@ -458,7 +416,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"device_keys": device_key}
             )
@@ -501,7 +459,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
         )
 
@@ -515,14 +473,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(
                 other_user, {"master_key": other_master_key}
             )
         )
 
         # test various signature failures (see below)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -602,20 +560,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         user_failures = ret["failures"][local_user]
+        self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
         self.assertEqual(
-            user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+            user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
         )
-        self.assertEqual(
-            user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
-        )
-        self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+        self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
 
         other_user_failures = ret["failures"][other_user]
+        self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
         self.assertEqual(
-            other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
-        )
-        self.assertEqual(
-            other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+            other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
         )
 
         # test successful signatures
@@ -623,7 +577,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -636,7 +590,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(ret["failures"], {})
 
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.query_devices(
                 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
             )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 45f201a399..d7498aa51a 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -19,14 +19,9 @@ import copy
 
 import mock
 
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
 
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
 
 # sample room_key data for use in the tests
 room_keys = {
@@ -45,51 +40,38 @@ room_keys = {
 }
 
 
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(replication_layer=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, replication_layer=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
-        self.local_user = "@boris:" + self.hs.hostname
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_room_keys_handler()
+        self.local_user = "@boris:" + hs.hostname
 
-    @defer.inlineCallbacks
     def test_get_missing_current_version_info(self):
         """Check that we get a 404 if we ask for info about the current version
         if there is no version.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_version_info(self):
         """Check that we get a 404 if we ask for info about a specific version
         if it doesn't exist.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "bogus_version"),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_create_version(self):
-        """Check that we can create and then retrieve versions.
-        """
-        res = yield defer.ensureDeferred(
+        """Check that we can create and then retrieve versions."""
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -101,7 +83,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
         self.assertIsInstance(version_etag, str)
         del res["etag"]
@@ -116,9 +98,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as a specific version
-        res = yield defer.ensureDeferred(
-            self.handler.get_version_info(self.local_user, "1")
-        )
+        res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         self.assertDictEqual(
@@ -132,7 +112,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # upload a new one...
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -144,7 +124,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "2")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -156,11 +136,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_version(self):
-        """Check that we can update versions.
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can update versions."""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -171,7 +149,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -185,7 +163,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {})
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -197,32 +175,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_missing_version(self):
-        """Check that we get a 404 on updating nonexistent versions
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    "1",
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "1",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on updating nonexistent versions"""
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                "1",
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "1",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_update_omitted_version(self):
-        """Check that the update succeeds if the version is missing from the body
-        """
-        version = yield defer.ensureDeferred(
+        """Check that the update succeeds if the version is missing from the body"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -233,7 +205,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -245,7 +217,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
             res,
@@ -257,11 +229,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_bad_version(self):
-        """Check that we get a 400 if the version in the body doesn't match
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we get a 400 if the version in the body doesn't match"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -272,52 +242,38 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    version,
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "incorrect",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "incorrect",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 400)
 
-    @defer.inlineCallbacks
     def test_delete_missing_version(self):
-        """Check that we get a 404 on deleting nonexistent versions
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.delete_version(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on deleting nonexistent versions"""
+        e = self.get_failure(
+            self.handler.delete_version(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_missing_current_version(self):
-        """Check that we get a 404 on deleting nonexistent current version
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on deleting nonexistent current version"""
+        e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_version(self):
-        """Check that we can create and then delete versions.
-        """
-        res = yield defer.ensureDeferred(
+        """Check that we can create and then delete versions."""
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -329,36 +285,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can delete it
-        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+        self.get_success(self.handler.delete_version(self.local_user, "1"))
 
         # check that it's gone
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_backup(self):
-        """Check that we get a 404 on querying missing backup
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_room_keys(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on querying missing backup"""
+        e = self.get_failure(
+            self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_room_keys(self):
-        """Check we get an empty response from an empty backup
-        """
-        version = yield defer.ensureDeferred(
+        """Check we get an empty response from an empty backup"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -369,33 +315,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, {"rooms": {}})
 
     # TODO: test the locking semantics when uploading room_keys,
     # although this is probably best done in sytest
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_no_versions(self):
-        """Check that we get a 404 on uploading keys when no versions are defined
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on uploading keys when no versions are defined"""
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_bogus_version(self):
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -406,22 +345,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(
-                    self.local_user, "bogus_version", room_keys
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_wrong_version(self):
-        """Check that we get a 403 on uploading keys for an old version
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we get a 403 on uploading keys for an old version"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -432,7 +365,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -443,20 +376,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "2")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "1", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 403)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_insert(self):
-        """Check that we can insert and retrieve keys for a session
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can insert and retrieve keys for a session"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -467,17 +395,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given room
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
@@ -485,18 +411,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given session_id
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
         self.assertDictEqual(res, room_keys)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -507,12 +432,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
         # get the etag to compare to future versions
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
 
@@ -522,37 +447,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should NOT be equal now, since the key changed
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
 
@@ -560,28 +481,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # TODO: check edge cases as well as the common variations here
 
-    @defer.inlineCallbacks
     def test_delete_room_keys(self):
-        """Check that we can insert and delete keys for a session
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can insert and delete keys for a session"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -593,13 +510,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(version, "1")
 
         # check for bulk-delete
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
-            self.handler.delete_room_keys(self.local_user, version)
-        )
-        res = yield defer.ensureDeferred(
+        self.get_success(self.handler.delete_room_keys(self.local_user, version))
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -607,15 +522,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per room
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -623,15 +538,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per session
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0b24b89a2e..3af361195b 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
 from unittest import TestCase
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
 from synapse.federation.federation_base import event_from_pdu_json
@@ -191,6 +191,58 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invite_by_user_ratelimit(self):
+        """Tests that invites from federation to a particular user are
+        actually rate-limited.
+        """
+        other_server = "otherserver"
+        other_user = "@otheruser:" + other_server
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        def create_invite():
+            room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+            room_version = self.get_success(self.store.get_room_version(room_id))
+            return event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": other_user,
+                    "state_key": "@user:test",
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+
+        for i in range(3):
+            event = create_invite()
+            self.get_success(
+                self.handler.on_invite_request(
+                    other_server,
+                    event,
+                    event.room_version,
+                )
+            )
+
+        event = create_invite()
+        self.get_failure(
+            self.handler.on_invite_request(
+                other_server,
+                event,
+                event.room_version,
+            ),
+            exc=LimitExceededError,
+        )
+
     def _build_and_send_join_event(self, other_server, other_user, room_id):
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f955dfa490..a0d1ebdbe3 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -44,7 +44,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
         self.info = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+            self.hs.get_datastore().get_user_by_access_token(
+                self.access_token,
+            )
         )
         self.token_id = self.info.token_id
 
@@ -169,8 +171,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
     def test_allow_server_acl(self):
-        """Test that sending an ACL that blocks everyone but ourselves works.
-        """
+        """Test that sending an ACL that blocks everyone but ourselves works."""
 
         self.helper.send_state(
             self.room_id,
@@ -181,8 +182,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         )
 
     def test_deny_server_acl_block_outselves(self):
-        """Test that sending an ACL that blocks ourselves does not work.
-        """
+        """Test that sending an ACL that blocks ourselves does not work."""
         self.helper.send_state(
             self.room_id,
             EventTypes.ServerACL,
@@ -192,8 +192,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         )
 
     def test_deny_redact_server_acl(self):
-        """Test that attempting to redact an ACL is blocked.
-        """
+        """Test that attempting to redact an ACL is blocked."""
 
         body = self.helper.send_state(
             self.room_id,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b3dfa40d25..cf1de28fa9 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
 
-from tests.test_utils import FakeResponse, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -40,7 +40,7 @@ ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
 CLIENT_SECRET = "test-client-secret"
 BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
 AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -58,12 +58,6 @@ COMMON_CONFIG = {
 }
 
 
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
 class TestMappingProvider:
     @staticmethod
     def parse_config(config):
@@ -137,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor, clock):
-
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = "Synapse Test"
@@ -157,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.provider._provider_metadata, values)
+        """Modify the result that will be returned by the well-known query"""
+
+        async def patched_get_json(uri):
+            res = await get_json(uri)
+            if uri == WELL_KNOWN:
+                res.update(values)
+            return res
+
+        return patch.object(self.http_client, "get_json", patched_get_json)
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -218,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
-        with self.metadata_edit({"jwks_uri": None}):
+        original = self.provider.load_metadata
+
+        async def patched_load_metadata():
+            m = (await original()).copy()
+            m.update({"jwks_uri": None})
+            return m
+
+        with patch.object(self.provider, "load_metadata", patched_load_metadata):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
@@ -228,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
-    @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
         h = self.provider
 
+        def force_load_metadata():
+            async def force_load():
+                return await h.load_metadata(force=True)
+
+            return get_awaitable_result(force_load())
+
         # Default test config does not throw
-        h._validate_metadata()
+        force_load_metadata()
 
         with self.metadata_edit({"issuer": None}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"issuer": "http://insecure/"}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"authorization_endpoint": None}):
             self.assertRaisesRegex(
-                ValueError, "authorization_endpoint", h._validate_metadata
+                ValueError, "authorization_endpoint", force_load_metadata
             )
 
         with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
             self.assertRaisesRegex(
-                ValueError, "authorization_endpoint", h._validate_metadata
+                ValueError, "authorization_endpoint", force_load_metadata
             )
 
         with self.metadata_edit({"token_endpoint": None}):
-            self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
 
         with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
-            self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
 
         with self.metadata_edit({"jwks_uri": None}):
-            self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
 
         with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
-            self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
 
         with self.metadata_edit({"response_types_supported": ["id_token"]}):
             self.assertRaisesRegex(
-                ValueError, "response_types_supported", h._validate_metadata
+                ValueError, "response_types_supported", force_load_metadata
             )
 
         with self.metadata_edit(
             {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
         ):
             # should not throw, as client_secret_basic is the default auth method
-            h._validate_metadata()
+            force_load_metadata()
 
         with self.metadata_edit(
             {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
@@ -284,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRaisesRegex(
                 ValueError,
                 "token_endpoint_auth_methods_supported",
-                h._validate_metadata,
+                force_load_metadata,
             )
 
         # Tests for configs that require the userinfo endpoint
@@ -293,28 +306,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         h._user_profile_method = "userinfo_endpoint"
         self.assertTrue(h._uses_userinfo)
 
-        # Revert the profile method and do not request the "openid" scope.
+        # Revert the profile method and do not request the "openid" scope: this should
+        # mean that we check for a userinfo endpoint
         h._user_profile_method = "auto"
         h._scopes = []
         self.assertTrue(h._uses_userinfo)
-        self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+        with self.metadata_edit({"userinfo_endpoint": None}):
+            self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
 
-        with self.metadata_edit(
-            {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
-        ):
-            # Shouldn't raise with a valid userinfo, even without
-            h._validate_metadata()
+        with self.metadata_edit({"jwks_uri": None}):
+            # Shouldn't raise with a valid userinfo, even without jwks
+            force_load_metadata()
 
     @override_config({"oidc_config": {"skip_verification": True}})
     def test_skip_verification(self):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.provider._validate_metadata()
+            get_awaitable_result(self.provider.load_metadata())
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
-        req = Mock(spec=["addCookie"])
+        req = Mock(spec=["cookies"])
+        req.cookies = []
+
         url = self.get_success(
             self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
@@ -333,16 +348,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(len(params["state"]), 1)
         self.assertEqual(len(params["nonce"]), 1)
 
-        # Check what is in the cookie
-        # note: python3.5 mock does not have the .called_once() method
-        calls = req.addCookie.call_args_list
-        self.assertEqual(len(calls), 1)  # called once
-        # For some reason, call.args does not work with python3.5
-        args = calls[0][0]
-        kwargs = calls[0][1]
-        self.assertEqual(args[0], COOKIE_NAME)
-        self.assertEqual(kwargs["path"], COOKIE_PATH)
-        cookie = args[1]
+        # Check what is in the cookies
+        self.assertEqual(len(req.cookies), 2)  # two cookies
+        cookie_header = req.cookies[0]
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        parts = [p.strip() for p in cookie_header.split(b";")]
+        self.assertIn(b"Path=/_synapse/client/oidc", parts)
+        name, cookie = parts[0].split(b"=")
+        self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
         state = self.handler._token_generator._get_value_from_macaroon(
@@ -419,7 +434,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -450,7 +465,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -473,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(spec=["args", "getCookie", "cookies"])
 
         # Missing cookie
         request.args = {}
@@ -496,7 +511,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # Mismatching session
         session = self._generate_oidc_session_token(
-            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+            state="state",
+            nonce="nonce",
+            client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -551,7 +568,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Internal server error with no JSON body
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(
-                code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+                code=500,
+                phrase=b"Internal Server Error",
+                body=b"Not JSON",
             )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -571,7 +590,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+            return_value=FakeResponse(
+                code=400,
+                phrase=b"Bad request",
+                body=b"{}",
+            )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
@@ -579,7 +602,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # 2xx error with "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(
-                code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+                code=200,
+                phrase=b"OK",
+                body=b'{"error": "some_error"}',
             )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -616,14 +641,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state = "state"
         client_redirect_url = "http://client/redirect"
         session = self._generate_oidc_session_token(
-            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
+            state=state,
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
         )
         request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@foo:test", request, client_redirect_url, {"phone": "1234567"},
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
@@ -637,7 +668,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None,
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -648,7 +679,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None,
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -685,14 +716,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -707,7 +738,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -743,7 +774,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None,
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -779,7 +810,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None,
+            "@test_user1:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -875,7 +906,9 @@ async def _make_callback_with_userinfo(
     session = handler._token_generator.generate_oidc_session_token(
         state=state,
         session_data=OidcSessionData(
-            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+            idp_id="oidc",
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
         ),
     )
     request = _build_callback_request("code", state, session)
@@ -909,13 +942,14 @@ def _build_callback_request(
         spec=[
             "args",
             "getCookie",
-            "addCookie",
+            "cookies",
             "requestHeaders",
             "getClientIP",
             "getHeader",
         ]
     )
 
+    request.cookies = []
     request.getCookie.return_value = session
     request.args = {}
     request.args[b"code"] = [code.encode("utf-8")]
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index f816594ee4..a98a65ae67 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -231,8 +231,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_no_local_user_fallback_login(self):
-        """localdb_enabled can block login with the local password
-        """
+        """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
@@ -251,8 +250,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_no_local_user_fallback_ui_auth(self):
-        """localdb_enabled can block ui auth with the local password
-        """
+        """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
@@ -594,7 +592,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
     def _delete_device(
-        self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+        self,
+        access_token: str,
+        device: str,
+        body: Union[JsonDict, bytes] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
         channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0794b32c9c..be2ee26f07 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -589,8 +589,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
 
     def _add_new_user(self, room_id, user_id):
-        """Add new user to the room by creating an event and poking the federation API.
-        """
+        """Add new user to the room by creating an event and poking the federation API."""
 
         hostname = get_domain_from_id(user_id)
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 022943a10a..18ca8b84f5 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,25 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
     """ Tests profile management. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
+    def make_homeserver(self, reactor, clock):
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
@@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
 
         self.mock_registry.register_query_handler = register_query_handler
 
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
         )
+        return hs
 
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.frank = UserID.from_string("@1234ABCD:test")
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+        self.get_success(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
-        self.hs = hs
 
-    @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.frank)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.frank))
 
         self.assertEquals("Frank", displayname)
 
-    @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
             )
@@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank"
             )
@@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), ""
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            )
+            (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
         )
 
-    @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Setting displayname a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
-            )
+            ),
+            SynapseError,
         )
 
-        yield self.assertFailure(d, SynapseError)
-
-    @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
-            )
+            ),
+            AuthError,
         )
 
-        yield self.assertFailure(d, AuthError)
-
-    @defer.inlineCallbacks
     def test_get_other_name(self):
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.alice)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.alice))
 
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
@@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
             ignore_backoff=True,
         )
 
-    @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield defer.ensureDeferred(
-            self.store.set_profile_displayname("caroline", "Caroline")
-        )
+        self.get_success(self.store.create_profile("caroline"))
+        self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
-        response = yield defer.ensureDeferred(
+        response = self.get_success(
             self.query_handlers["profile"](
                 {"user_id": "@caroline:test", "field": "displayname"}
             )
@@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals({"displayname": "Caroline"}, response)
 
-    @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
-        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+        avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
 
         self.assertEquals("http://my.server/me.png", avatar_url)
 
-    @defer.inlineCallbacks
     def test_set_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/pic.gif",
         )
 
         # Set avatar again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -230,56 +201,44 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
-                self.frank, synapse.types.create_requester(self.frank), "",
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "",
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
 
-    @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 "http://my.server/pic.gif",
-            )
+            ),
+            SynapseError,
         )
-
-        yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 261c7083d1..029af2853e 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None
+            "@test_user1:test", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
 
+    @override_config(
+        {
+            "saml2_config": {
+                "attribute_requirements": [
+                    {"attribute": "userGroup", "value": "staff"},
+                    {"attribute": "department", "value": "sales"},
+                ],
+            },
+        }
+    )
+    def test_attribute_requirements(self):
+        """The required attributes must be met from the SAML response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have the proper department.
+        saml_response = FakeAuthnResponse(
+            {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+        )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        saml_response = FakeAuthnResponse(
+            {
+                "uid": "test_user",
+                "username": "test_user",
+                "userGroup": ["staff", "admin"],
+                "department": ["sales"],
+            }
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 96e5bdac4a..24e7138196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -143,14 +143,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
 
         self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
-            ([], 0)
+        self.datastore.get_new_device_msgs_for_remote = (
+            lambda *args, **kargs: make_awaitable(([], 0))
         )
-        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
-            None
+        self.datastore.delete_device_msgs_for_remote = (
+            lambda *args, **kargs: make_awaitable(None)
         )
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
-            None
+        self.datastore.set_received_txn_response = (
+            lambda *args, **kwargs: make_awaitable(None)
         )
 
     def test_started_typing_local(self):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9c886d671a..3572e54c5d 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -200,7 +200,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
 
@@ -209,7 +211,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
 
@@ -227,7 +231,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
 
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 686012dd25..4c56253da5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -518,8 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.successResultOf(test_d)
 
     def test_get_well_known(self):
-        """Test the behaviour when the .well-known delegates elsewhere
-        """
+        """Test the behaviour when the .well-known delegates elsewhere"""
 
         self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -1135,8 +1134,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertIsNone(r.delegated_server)
 
     def test_srv_fallbacks(self):
-        """Test that other SRV results are tried if the first one fails.
-        """
+        """Test that other SRV results are tried if the first one fails."""
         self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
             [
                 Server(host=b"target.com", port=8443),
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index f17c122e93..2d9b733be0 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -18,6 +18,7 @@ from mock import Mock
 
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
+from twisted.web.iweb import UNKNOWN_LENGTH
 
 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
 
@@ -27,12 +28,12 @@ from tests.unittest import TestCase
 class ReadBodyWithMaxSizeTests(TestCase):
     def setUp(self):
         """Start reading the body, returns the response, result and proto"""
-        self.response = Mock()
+        response = Mock(length=UNKNOWN_LENGTH)
         self.result = BytesIO()
-        self.deferred = read_body_with_max_size(self.response, self.result, 6)
+        self.deferred = read_body_with_max_size(response, self.result, 6)
 
         # Fish the protocol out of the response.
-        self.protocol = self.response.deliverBody.call_args[0][0]
+        self.protocol = response.deliverBody.call_args[0][0]
         self.protocol.transport = Mock()
 
     def _cleanup_error(self):
@@ -88,7 +89,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.protocol.dataReceived(b"1234567890")
         self.assertIsInstance(self.deferred.result, Failure)
         self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self.protocol.transport.loseConnection.assert_called_once()
+        self.protocol.transport.abortConnection.assert_called_once()
 
         # More data might have come in.
         self.protocol.dataReceived(b"1234567890")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 27206ca3db..edacd1b566 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,7 +100,10 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         # Check that the event was sent
         self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
-            expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
+            expected_requester,
+            event_dict,
+            ratelimit=False,
+            ignore_shadow_ban=True,
         )
 
         # Create and send a state event
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 961bf09de9..22f452ec24 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -124,13 +124,18 @@ class EmailPusherTests(HomeserverTestCase):
         )
         self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
 
-        # The other user sends some messages
+        # The other user sends a single message.
         self.helper.send(room, body="Hi!", tok=self.others[0].token)
-        self.helper.send(room, body="There!", tok=self.others[0].token)
 
         # We should get emailed about that message
         self._check_for_mail()
 
+        # The other user sends multiple messages.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[0].token)
+
+        self._check_for_mail()
+
     def test_invite_sends_email(self):
         # Create a room and invite the user to it
         room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
@@ -187,6 +192,75 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
+    def test_multiple_rooms(self):
+        # We want to test multiple notifications from multiple rooms, so we pause
+        # processing of push while we send messages.
+        self.pusher._pause_processing()
+
+        # Create a simple room with multiple other users
+        rooms = [
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+        ]
+
+        for r, other in zip(rooms, self.others):
+            self.helper.invite(
+                room=r, src=self.user_id, tok=self.access_token, targ=other.id
+            )
+            self.helper.join(room=r, user=other.id, tok=other.token)
+
+        # The other users send some messages
+        self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+        # Nothing should have happened yet, as we're paused.
+        assert not self.email_attempts
+
+        self.pusher._resume_processing()
+
+        # We should get emailed about those messages
+        self._check_for_mail()
+
+    def test_empty_room(self):
+        """All users leaving a room shouldn't cause the pusher to break."""
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends a single message.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+        # Leave the room before the message is processed.
+        self.helper.leave(room, self.user_id, tok=self.access_token)
+        self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+        # We should get emailed about that message
+        self._check_for_mail()
+
+    def test_empty_room_multiple_messages(self):
+        """All users leaving a room shouldn't cause the pusher to break."""
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends a single message.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[0].token)
+
+        # Leave the room before the message is processed.
+        self.helper.leave(room, self.user_id, tok=self.access_token)
+        self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+        # We should get emailed about that message
+        self._check_for_mail()
+
     def test_encrypted_message(self):
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -239,3 +313,6 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+
+        # Reset the attempts.
+        self.email_attempts = []
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+    """
+    A fake data store which stores a mapping of state key to event content.
+    (I.e. the state key is used as the event ID.)
+    """
+
+    def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+        """
+        Args:
+            events: A state map to event contents.
+        """
+        self._events = {}
+
+        for i, (event_id, content) in enumerate(events):
+            self._events[event_id] = FrozenEvent(
+                {
+                    "event_id": "$event_id",
+                    "type": event_id[0],
+                    "sender": "@user:test",
+                    "state_key": event_id[1],
+                    "room_id": "#room:test",
+                    "content": content,
+                    "origin_server_ts": i,
+                },
+                RoomVersions.V1,
+            )
+
+    async def get_event(
+        self, event_id: StateKey, allow_none: bool = False
+    ) -> Optional[FrozenEvent]:
+        assert allow_none, "Mock not configured for allow_none = False"
+
+        return self._events.get(event_id)
+
+    async def get_events(self, event_ids: Iterable[StateKey]):
+        # This is cheating since it just returns all events.
+        return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+    USER_ID = "@test:test"
+    OTHER_USER_ID = "@user:test"
+
+    def _calculate_room_name(
+        self,
+        events: StateMap[dict],
+        user_id: str = "",
+        fallback_to_members: bool = True,
+        fallback_to_single_member: bool = True,
+    ):
+        # This isn't 100% accurate, but works with MockDataStore.
+        room_state_ids = {k[0]: k[0] for k in events}
+
+        return self.get_success(
+            calculate_room_name(
+                MockDataStore(events),
+                room_state_ids,
+                user_id or self.USER_ID,
+                fallback_to_members,
+                fallback_to_single_member,
+            )
+        )
+
+    def test_name(self):
+        """A room name event should be used."""
+        events = [
+            ((EventTypes.Name, ""), {"name": "test-name"}),
+        ]
+        self.assertEqual("test-name", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.Name, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.Name, ""), {"name": 1})]
+        self.assertEqual(1, self._calculate_room_name(events))
+
+    def test_canonical_alias(self):
+        """An canonical alias should be used."""
+        events = [
+            ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+        ]
+        self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_invite(self):
+        """An invite has special behaviour."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+        ]
+        self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+        # Ensure this logic is skipped if we don't fallback to members.
+        self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+        # Check if the event content has garbage.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+        ]
+        self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+        # No member event for sender.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+        ]
+        self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+    def test_no_members(self):
+        """Behaviour of an empty room."""
+        events = []
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        # Note that events with invalid (or missing) membership are ignored.
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+            ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+        ]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_no_other_members(self):
+        """Behaviour of a room with no other members in it."""
+        events = [
+            (
+                (EventTypes.Member, self.USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Me"},
+            ),
+        ]
+        self.assertEqual("Me", self._calculate_room_name(events))
+
+        # Check if the event content has no displayname.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("@test:test", self._calculate_room_name(events))
+
+        # 3pid invite, use the other user (who is set as the sender).
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual(
+            "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+        )
+
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+            ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+        ]
+        self.assertEqual(
+            "Inviting email address",
+            self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+        )
+
+    def test_one_other_member(self):
+        """Behaviour of a room with a single other member."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+        ]
+        self.assertEqual("Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+
+        # Check if the event content has no displayname and is an invite.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.INVITE},
+            ),
+        ]
+        self.assertEqual("@user:test", self._calculate_room_name(events))
+
+    def test_other_members(self):
+        """Behaviour of a room with multiple other members."""
+        # Two other members.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+            ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+        # Three or more other members.
+        events.append(
+            ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+        )
+        self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                 "type": "m.room.history_visibility",
                 "sender": "@user:test",
                 "state_key": "",
-                "room_id": "@room:test",
+                "room_id": "#room:test",
                 "content": content,
             },
             RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3379189785..f6a6aed35e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         repl_handler = ReplicationCommandHandler(self.worker_hs)
         self.client = ClientReplicationStreamProtocol(
-            self.worker_hs, "client", "test", clock, repl_handler,
+            self.worker_hs,
+            "client",
+            "test",
+            clock,
+            repl_handler,
         )
 
         self._client_transport = None
@@ -212,6 +216,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Fake in memory Redis server that servers can connect to.
         self._redis_server = FakeRedisPubSubServer()
 
+        # We may have an attempt to connect to redis for the external cache already.
+        self.connect_any_redis_attempts()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
@@ -225,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if self.hs.config.redis.redis_enabled:
             # Handle attempts to connect to fake redis server.
             self.reactor.add_tcp_client_callback(
-                "localhost", 6379, self.connect_any_redis_attempts,
+                "localhost",
+                6379,
+                self.connect_any_redis_attempts,
             )
 
             self.hs.get_tcp_replication().start_replication(self.hs)
@@ -243,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         )
 
     def create_test_resource(self):
-        """Overrides `HomeserverTestCase.create_test_resource`.
-        """
+        """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
         # subclassses.
@@ -293,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             if instance_loc.host not in self.reactor.lookups:
                 raise Exception(
                     "Host does not have an IP for instance_map[%r].host = %r"
-                    % (instance_name, instance_loc.host,)
+                    % (
+                        instance_name,
+                        instance_loc.host,
+                    )
                 )
 
             self.reactor.add_tcp_client_callback(
@@ -312,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if not worker_hs.config.redis_enabled:
             repl_handler = ReplicationCommandHandler(worker_hs)
             client = ClientReplicationStreamProtocol(
-                worker_hs, "client", "test", self.clock, repl_handler,
+                worker_hs,
+                "client",
+                "test",
+                self.clock,
+                repl_handler,
             )
             server = self.server_factory.buildProtocol(None)
 
@@ -401,25 +416,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         fake one.
         """
         clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, "localhost")
-        self.assertEqual(port, 6379)
+        while clients:
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+            self.assertEqual(host, "localhost")
+            self.assertEqual(port, 6379)
 
-        client_protocol = client_factory.buildProtocol(None)
-        server_protocol = self._redis_server.buildProtocol(None)
+            client_protocol = client_factory.buildProtocol(None)
+            server_protocol = self._redis_server.buildProtocol(None)
 
-        client_to_server_transport = FakeTransport(
-            server_protocol, self.reactor, client_protocol
-        )
-        client_protocol.makeConnection(client_to_server_transport)
-
-        server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, server_protocol
-        )
-        server_protocol.makeConnection(server_to_client_transport)
+            client_to_server_transport = FakeTransport(
+                server_protocol, self.reactor, client_protocol
+            )
+            client_protocol.makeConnection(client_to_server_transport)
 
-        return client_to_server_transport, server_to_client_transport
+            server_to_client_transport = FakeTransport(
+                client_protocol, self.reactor, server_protocol
+            )
+            server_protocol.makeConnection(server_to_client_transport)
 
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -484,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel):
             self._pull_to_push_producer.stop()
 
     def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used
-        """
+        """Check whether the connection can be re-used"""
         # We hijack this to always say no for ease of wiring stuff up in
         # `handle_http_replication_attempt`.
         request.responseHeaders.setRawHeaders(b"connection", [b"close"])
@@ -493,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel):
 
 
 class _PullToPushProducer:
-    """A push producer that wraps a pull producer.
-    """
+    """A push producer that wraps a pull producer."""
 
     def __init__(
         self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -511,39 +522,33 @@ class _PullToPushProducer:
         self._start_loop()
 
     def _start_loop(self):
-        """Start the looping call to
-        """
+        """Start the looping call to"""
 
         if not self._looping_call:
             # Start a looping call which runs every tick.
             self._looping_call = self._clock.looping_call(self._run_once, 0)
 
     def stop(self):
-        """Stops calling resumeProducing.
-        """
+        """Stops calling resumeProducing."""
         if self._looping_call:
             self._looping_call.stop()
             self._looping_call = None
 
     def pauseProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
 
     def resumeProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self._start_loop()
 
     def stopProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
         self._producer.stopProducing()
 
     def _run_once(self):
-        """Calls resumeProducing on producer once.
-        """
+        """Calls resumeProducing on producer once."""
 
         try:
             self._producer.resumeProducing()
@@ -558,25 +563,21 @@ class _PullToPushProducer:
 
 
 class FakeRedisPubSubServer:
-    """A fake Redis server for pub/sub.
-    """
+    """A fake Redis server for pub/sub."""
 
     def __init__(self):
         self._subscribers = set()
 
     def add_subscriber(self, conn):
-        """A connection has called SUBSCRIBE
-        """
+        """A connection has called SUBSCRIBE"""
         self._subscribers.add(conn)
 
     def remove_subscriber(self, conn):
-        """A connection has called UNSUBSCRIBE
-        """
+        """A connection has called UNSUBSCRIBE"""
         self._subscribers.discard(conn)
 
     def publish(self, conn, channel, msg) -> int:
-        """A connection want to publish a message to subscribers.
-        """
+        """A connection want to publish a message to subscribers."""
         for sub in self._subscribers:
             sub.send(["message", channel, msg])
 
@@ -587,8 +588,7 @@ class FakeRedisPubSubServer:
 
 
 class FakeRedisPubSubProtocol(Protocol):
-    """A connection from a client talking to the fake Redis server.
-    """
+    """A connection from a client talking to the fake Redis server."""
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
@@ -612,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol):
             self.handle_command(msg[0], *msg[1:])
 
     def handle_command(self, command, *args):
-        """Received a Redis command from the client.
-        """
+        """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
         if command == b"PUBLISH":
@@ -624,12 +623,17 @@ class FakeRedisPubSubProtocol(Protocol):
             (channel,) = args
             self._server.add_subscriber(self)
             self.send(["subscribe", channel, 1])
+
+        # Since we use SET/GET to cache things we can safely no-op them.
+        elif command == b"SET":
+            self.send("OK")
+        elif command == b"GET":
+            self.send(None)
         else:
             raise Exception("Unknown command")
 
     def send(self, msg):
-        """Send a message back to the client.
-        """
+        """Send a message back to the client."""
         raw = self.encode(msg).encode("utf-8")
 
         self.transport.write(raw)
@@ -645,6 +649,8 @@ class FakeRedisPubSubProtocol(Protocol):
             # We assume bytes are just unicode strings.
             obj = obj.decode("utf-8")
 
+        if obj is None:
+            return "$-1\r\n"
         if isinstance(obj, str):
             return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
         if isinstance(obj, int):
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index c0ee1cfbd6..0ceb0f935c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -66,7 +66,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         self.get_success(
             self.master_store.store_room(
-                ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
+                ROOM_ID,
+                USER_ID,
+                is_public=False,
+                room_version=RoomVersions.V1,
             )
         )
 
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6a5116dd2a..153634d4ee 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -23,8 +23,7 @@ from tests.replication._base import BaseStreamTestCase
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
     def test_update_function_room_account_data_limit(self):
-        """Test replication with many room account data updates
-        """
+        """Test replication with many room account data updates"""
         store = self.hs.get_datastore()
 
         # generate lots of account data updates
@@ -70,8 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
         self.assertEqual([], received_rows)
 
     def test_update_function_global_account_data_limit(self):
-        """Test replication with many global account data updates
-        """
+        """Test replication with many global account data updates"""
         store = self.hs.get_datastore()
 
         # generate lots of account data updates
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index bad0df08cf..77856fc304 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -129,7 +129,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
         pls["users"][OTHER_USER] = 50
         self.helper.send_state(
-            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+            self.room_id,
+            EventTypes.PowerLevels,
+            pls,
+            tok=self.user_tok,
         )
 
         # this is the point in the DAG where we make a fork
@@ -255,8 +258,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             self.assertIsNone(sr.event_id)
 
     def test_update_function_state_row_limit(self):
-        """Test replication with many state events over several stream ids.
-        """
+        """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
         # spread out the state changes over a few stream IDs.
@@ -282,7 +284,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
         pls["users"].update({u: 50 for u in user_ids})
         self.helper.send_state(
-            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+            self.room_id,
+            EventTypes.PowerLevels,
+            pls,
+            tok=self.user_tok,
         )
 
         # this is the point in the DAG where we make a fork
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index d1c15caeb0..1fe9d5b4d0 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -28,8 +28,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
-        """Create a new direct TCP replication connection
-        """
+        """Create a new direct TCP replication connection"""
 
         proto = self.factory.buildProtocol(("127.0.0.1", 0))
         transport = StringTransport()
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index f35a5235e1..f8fd8a843c 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -79,8 +79,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         )
 
     def test_no_auth(self):
-        """With no authentication the request should finish.
-        """
+        """With no authentication the request should finish."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
 
@@ -89,8 +88,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
 
     @override_config({"main_replication_secret": "my-secret"})
     def test_missing_auth(self):
-        """If the main process expects a secret that is not provided, an error results.
-        """
+        """If the main process expects a secret that is not provided, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
@@ -101,15 +99,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         }
     )
     def test_unauthorized(self):
-        """If the main process receives the wrong secret, an error results.
-        """
+        """If the main process receives the wrong secret, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
     @override_config({"worker_replication_secret": "my-secret"})
     def test_authorized(self):
-        """The request should finish when the worker provides the authentication header.
-        """
+        """The request should finish when the worker provides the authentication header."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
 
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 4608b65a0c..5da1d5dc4d 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -35,8 +35,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         return config
 
     def test_register_single_worker(self):
-        """Test that registration works when using a single client reader worker.
-        """
+        """Test that registration works when using a single client reader worker."""
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
         site = self._hs_to_site[worker_hs]
 
@@ -66,8 +65,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
 
     def test_register_multi_worker(self):
-        """Test that registration works when using multiple client reader workers.
-        """
+        """Test that registration works when using multiple client reader workers."""
         worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index d1feca961f..7ff11cde10 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -36,8 +36,7 @@ test_server_connection_factory = None
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks running multiple media repos work correctly.
-    """
+    """Checks running multiple media repos work correctly."""
 
     servlets = [
         admin.register_servlets_for_client_rest_resource,
@@ -124,8 +123,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return channel, request
 
     def test_basic(self):
-        """Test basic fetching of remote media from a single worker.
-        """
+        """Test basic fetching of remote media from a single worker."""
         hs1 = self.make_worker_hs("synapse.app.generic_worker")
 
         channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
@@ -223,16 +221,14 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(start_count + 3, self._count_remote_thumbnails())
 
     def _count_remote_media(self) -> int:
-        """Count the number of files in our remote media directory.
-        """
+        """Count the number of files in our remote media directory."""
         path = os.path.join(
             self.hs.get_media_repository().primary_base_path, "remote_content"
         )
         return sum(len(files) for _, _, files in os.walk(path))
 
     def _count_remote_thumbnails(self) -> int:
-        """Count the number of files in our remote thumbnails directory.
-        """
+        """Count the number of files in our remote thumbnails directory."""
         path = os.path.join(
             self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
         )
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 800ad94a04..f118fe32af 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -27,8 +27,7 @@ logger = logging.getLogger(__name__)
 
 
 class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks pusher sharding works
-    """
+    """Checks pusher sharding works"""
 
     servlets = [
         admin.register_servlets_for_client_rest_resource,
@@ -88,11 +87,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         return event_id
 
     def test_send_push_single_worker(self):
-        """Test that registration works when using a pusher worker.
-        """
+        """Test that registration works when using a pusher worker."""
         http_client_mock = Mock(spec_set=["post_json_get_json"])
-        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
@@ -119,11 +117,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
     def test_send_push_multiple_workers(self):
-        """Test that registration works when using sharded pusher workers.
-        """
+        """Test that registration works when using sharded pusher workers."""
         http_client_mock1 = Mock(spec_set=["post_json_get_json"])
-        http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock1.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
@@ -137,8 +134,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
-        http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock2.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 8d494ebc03..c9b773fbd2 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -29,8 +29,7 @@ logger = logging.getLogger(__name__)
 
 
 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks event persisting sharding works
-    """
+    """Checks event persisting sharding works"""
 
     # Event persister sharding requires postgres (due to needing
     # `MutliWriterIdGenerator`).
@@ -63,8 +62,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         return conf
 
     def _create_room(self, room_id: str, user_id: str, tok: str):
-        """Create a room with given room_id
-        """
+        """Create a room with given room_id"""
 
         # We control the room ID generation by patching out the
         # `_generate_room_id` method
@@ -91,11 +89,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         """
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker1"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
         )
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker2"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
         )
 
         persisted_on_1 = False
@@ -139,15 +139,18 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         """
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker1"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
         )
 
         worker_hs2 = self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker2"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
         )
 
         sync_hs = self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "sync"},
+            "synapse.app.generic_worker",
+            {"worker_name": "sync"},
         )
         sync_hs_site = self._hs_to_site[sync_hs]
 
@@ -323,7 +326,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=f".format(
-                room_id2, vector_clock_token, prev_batch2,
+                room_id2,
+                vector_clock_token,
+                prev_batch2,
             ),
             access_token=access_token,
         )
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 9d22c04073..057e27372e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -130,8 +130,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         )
 
     def _get_groups_user_is_in(self, access_token):
-        """Returns the list of groups the user is in (given their access token)
-        """
+        """Returns the list of groups the user is in (given their access token)"""
         channel = self.make_request(
             "GET", "/joined_groups".encode("ascii"), access_token=access_token
         )
@@ -142,8 +141,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
 
 class QuarantineMediaTestCase(unittest.HomeserverTestCase):
-    """Test /quarantine_media admin API.
-    """
+    """Test /quarantine_media admin API."""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -237,7 +235,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         # Attempt quarantine media APIs as non-admin
         url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
         channel = self.make_request(
-            "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+            "POST",
+            url.encode("ascii"),
+            access_token=non_admin_user_tok,
         )
 
         # Expect a forbidden error
@@ -250,7 +250,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         # And the roomID/userID endpoint
         url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
         channel = self.make_request(
-            "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+            "POST",
+            url.encode("ascii"),
+            access_token=non_admin_user_tok,
         )
 
         # Expect a forbidden error
@@ -294,7 +296,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             urllib.parse.quote(server_name),
             urllib.parse.quote(media_id),
         )
-        channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=admin_user_tok,
+        )
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
@@ -346,7 +352,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
                 room_id
             )
-        channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=admin_user_tok,
+        )
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
         self.assertEqual(
@@ -391,7 +401,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             non_admin_user
         )
         channel = self.make_request(
-            "POST", url.encode("ascii"), access_token=admin_user_tok,
+            "POST",
+            url.encode("ascii"),
+            access_token=admin_user_tok,
         )
         self.pump(1.0)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -437,7 +449,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             non_admin_user
         )
         channel = self.make_request(
-            "POST", url.encode("ascii"), access_token=admin_user_tok,
+            "POST",
+            url.encode("ascii"),
+            access_token=admin_user_tok,
         )
         self.pump(1.0)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 248c4442c3..2a1bcf1760 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -70,21 +70,27 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error is returned.
         """
         channel = self.make_request(
-            "GET", self.url, access_token=self.other_user_token,
+            "GET",
+            self.url,
+            access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         channel = self.make_request(
-            "PUT", self.url, access_token=self.other_user_token,
+            "PUT",
+            self.url,
+            access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         channel = self.make_request(
-            "DELETE", self.url, access_token=self.other_user_token,
+            "DELETE",
+            self.url,
+            access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -99,17 +105,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -123,17 +141,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -146,16 +176,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         # Delete unknown device returns status 200
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -190,7 +232,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
 
         # Ensure the display name was not updated.
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -207,12 +253,20 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "PUT",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Ensure the display name was not updated.
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -233,7 +287,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check new display_name
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -242,7 +300,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         Tests that a normal lookup for a device is successfully
         """
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -264,7 +326,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
 
         # Delete device
         channel = self.make_request(
-            "DELETE", self.url, access_token=self.admin_user_tok,
+            "DELETE",
+            self.url,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -306,7 +370,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("GET", self.url, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -316,7 +384,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -327,7 +399,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -339,7 +415,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get devices
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -355,7 +435,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
             self.login("user", "pass")
 
         # Get devices
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_devices, channel.json_body["total"])
@@ -404,7 +488,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("POST", self.url, access_token=other_user_token,)
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -414,7 +502,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
-        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -425,7 +517,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
 
-        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "POST",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index d0090faa4f..e30ffe4fa0 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -51,19 +51,23 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         # Two rooms and two users. Every user sends and reports every room event
         for i in range(5):
             self._create_event_and_report(
-                room_id=self.room_id1, user_tok=self.other_user_tok,
+                room_id=self.room_id1,
+                user_tok=self.other_user_tok,
             )
         for i in range(5):
             self._create_event_and_report(
-                room_id=self.room_id2, user_tok=self.other_user_tok,
+                room_id=self.room_id2,
+                user_tok=self.other_user_tok,
             )
         for i in range(5):
             self._create_event_and_report(
-                room_id=self.room_id1, user_tok=self.admin_user_tok,
+                room_id=self.room_id1,
+                user_tok=self.admin_user_tok,
             )
         for i in range(5):
             self._create_event_and_report(
-                room_id=self.room_id2, user_tok=self.admin_user_tok,
+                room_id=self.room_id2,
+                user_tok=self.admin_user_tok,
             )
 
         self.url = "/_synapse/admin/v1/event_reports"
@@ -82,7 +86,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.other_user_tok,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -92,7 +100,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
@@ -106,7 +118,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -121,7 +135,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -136,7 +152,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -213,7 +231,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         # fetch the most recent first, largest timestamp
         channel = self.make_request(
-            "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?dir=b",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -229,7 +249,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         # fetch the oldest first, smallest timestamp
         channel = self.make_request(
-            "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?dir=f",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -249,7 +271,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -262,7 +286,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -274,7 +300,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -288,7 +316,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of results is the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=20",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -299,7 +329,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=21",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -310,7 +342,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -322,7 +356,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
         channel = self.make_request(
-            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -331,8 +367,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
 
     def _create_event_and_report(self, room_id, user_tok):
-        """Create and report events
-        """
+        """Create and report events"""
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
@@ -345,8 +380,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
     def _check_fields(self, content):
-        """Checks that all attributes are present in an event report
-        """
+        """Checks that all attributes are present in an event report"""
         for c in content:
             self.assertIn("id", c)
             self.assertIn("received_ts", c)
@@ -381,7 +415,8 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
 
         self._create_event_and_report(
-            room_id=self.room_id1, user_tok=self.other_user_tok,
+            room_id=self.room_id1,
+            user_tok=self.other_user_tok,
         )
 
         # first created event report gets `id`=2
@@ -401,7 +436,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.other_user_tok,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -411,7 +450,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing get a reported event
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self._check_fields(channel.json_body)
@@ -479,8 +522,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Event report not found", channel.json_body["error"])
 
     def _create_event_and_report(self, room_id, user_tok):
-        """Create and report events
-        """
+        """Create and report events"""
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
@@ -493,8 +535,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
     def _check_fields(self, content):
-        """Checks that all attributes are present in a event report
-        """
+        """Checks that all attributes are present in a event report"""
         self.assertIn("id", content)
         self.assertIn("received_ts", content)
         self.assertIn("room_id", content)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 51a7731693..31db472cd3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -63,7 +63,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -74,7 +78,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -85,7 +93,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
 
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -139,12 +151,17 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
 
         # Delete media
-        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "DELETE",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            media_id, channel.json_body["deleted_media"][0],
+            media_id,
+            channel.json_body["deleted_media"][0],
         )
 
         # Attempt to access media
@@ -207,7 +224,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.other_user_token = self.login("user", "pass")
 
         channel = self.make_request(
-            "POST", self.url, access_token=self.other_user_token,
+            "POST",
+            self.url,
+            access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -220,7 +239,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
 
         channel = self.make_request(
-            "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
+            "POST",
+            url + "?before_ts=1234",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -230,7 +251,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         If the parameter `before_ts` is missing, an error is returned.
         """
-        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -243,7 +268,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         If parameters are invalid, an error is returned.
         """
         channel = self.make_request(
-            "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
+            "POST",
+            self.url + "?before_ts=-1234",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -304,7 +331,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            media_id, channel.json_body["deleted_media"][0],
+            media_id,
+            channel.json_body["deleted_media"][0],
         )
 
         self._access_media(server_and_media_id, False)
@@ -340,7 +368,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+            server_and_media_id.split("/")[1],
+            channel.json_body["deleted_media"][0],
         )
 
         self._access_media(server_and_media_id, False)
@@ -374,7 +403,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+            server_and_media_id.split("/")[1],
+            channel.json_body["deleted_media"][0],
         )
 
         self._access_media(server_and_media_id, False)
@@ -417,7 +447,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+            server_and_media_id.split("/")[1],
+            channel.json_body["deleted_media"][0],
         )
 
         self._access_media(server_and_media_id, False)
@@ -461,7 +492,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
         self.assertEqual(
-            server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0],
+            server_and_media_id.split("/")[1],
+            channel.json_body["deleted_media"][0],
         )
 
         self._access_media(server_and_media_id, False)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index a0f32c5512..b55160b70a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -127,8 +127,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         self._assert_peek(room_id, expect_code=403)
 
     def _assert_peek(self, room_id, expect_code):
-        """Assert that the admin user can (or cannot) peek into the room.
-        """
+        """Assert that the admin user can (or cannot) peek into the room."""
 
         url = "rooms/%s/initialSync" % (room_id,)
         channel = self.make_request(
@@ -186,7 +185,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+            "POST",
+            self.url,
+            json.dumps({}),
+            access_token=self.other_user_tok,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -199,7 +201,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
 
         channel = self.make_request(
-            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+            "POST",
+            url,
+            json.dumps({}),
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
@@ -212,12 +217,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/rooms/invalidroom/delete"
 
         channel = self.make_request(
-            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+            "POST",
+            url,
+            json.dumps({}),
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
-            "invalidroom is not a legal room ID", channel.json_body["error"],
+            "invalidroom is not a legal room ID",
+            channel.json_body["error"],
         )
 
     def test_new_room_user_does_not_exist(self):
@@ -254,7 +263,8 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
-            "User must be our own: @not:exist.bla", channel.json_body["error"],
+            "User must be our own: @not:exist.bla",
+            channel.json_body["error"],
         )
 
     def test_block_is_not_bool(self):
@@ -491,8 +501,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._assert_peek(self.room_id, expect_code=403)
 
     def _is_blocked(self, room_id, expect=True):
-        """Assert that the room is blocked or not
-        """
+        """Assert that the room is blocked or not"""
         d = self.store.is_room_blocked(room_id)
         if expect:
             self.assertTrue(self.get_success(d))
@@ -500,20 +509,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             self.assertIsNone(self.get_success(d))
 
     def _has_no_members(self, room_id):
-        """Assert there is now no longer anyone in the room
-        """
+        """Assert there is now no longer anyone in the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertEqual([], users_in_room)
 
     def _is_member(self, room_id, user_id):
-        """Test that user is member of the room
-        """
+        """Test that user is member of the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertIn(user_id, users_in_room)
 
     def _is_purged(self, room_id):
-        """Test that the following tables have been purged of all rows related to the room.
-        """
+        """Test that the following tables have been purged of all rows related to the room."""
         for table in PURGE_TABLES:
             count = self.get_success(
                 self.store.db_pool.simple_select_one_onecol(
@@ -527,8 +533,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
 
     def _assert_peek(self, room_id, expect_code):
-        """Assert that the admin user can (or cannot) peek into the room.
-        """
+        """Assert that the admin user can (or cannot) peek into the room."""
 
         url = "rooms/%s/initialSync" % (room_id,)
         channel = self.make_request(
@@ -548,8 +553,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
 
 class PurgeRoomTestCase(unittest.HomeserverTestCase):
-    """Test /purge_room admin API.
-    """
+    """Test /purge_room admin API."""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -594,8 +598,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
 
 
 class RoomTestCase(unittest.HomeserverTestCase):
-    """Test /room admin API.
-    """
+    """Test /room admin API."""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -623,7 +626,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
 
         # Check request completed successfully
@@ -685,7 +690,10 @@ class RoomTestCase(unittest.HomeserverTestCase):
         # Set the name of the rooms so we get a consistent returned ordering
         for idx, room_id in enumerate(room_ids):
             self.helper.send_state(
-                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+                room_id,
+                "m.room.name",
+                {"name": str(idx)},
+                tok=self.admin_user_tok,
             )
 
         # Request the list of rooms
@@ -704,7 +712,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 "name",
             )
             channel = self.make_request(
-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+                "GET",
+                url.encode("ascii"),
+                access_token=self.admin_user_tok,
             )
             self.assertEqual(
                 200, int(channel.result["code"]), msg=channel.result["body"]
@@ -744,7 +754,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
@@ -788,13 +800,18 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Set a name for the room
         self.helper.send_state(
-            room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+            room_id,
+            "m.room.name",
+            {"name": test_room_name},
+            tok=self.admin_user_tok,
         )
 
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
@@ -860,7 +877,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
             )
 
         def _order_test(
-            order_type: str, expected_room_list: List[str], reverse: bool = False,
+            order_type: str,
+            expected_room_list: List[str],
+            reverse: bool = False,
         ):
             """Request the list of rooms in a certain order. Assert that order is what
             we expect
@@ -875,7 +894,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
             if reverse:
                 url += "&dir=b"
             channel = self.make_request(
-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+                "GET",
+                url.encode("ascii"),
+                access_token=self.admin_user_tok,
             )
             self.assertEqual(200, channel.code, msg=channel.json_body)
 
@@ -907,13 +928,22 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
         self.helper.send_state(
-            room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+            room_id_1,
+            "m.room.name",
+            {"name": "A"},
+            tok=self.admin_user_tok,
         )
         self.helper.send_state(
-            room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+            room_id_2,
+            "m.room.name",
+            {"name": "B"},
+            tok=self.admin_user_tok,
         )
         self.helper.send_state(
-            room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+            room_id_3,
+            "m.room.name",
+            {"name": "C"},
+            tok=self.admin_user_tok,
         )
 
         # Set room canonical room aliases
@@ -990,10 +1020,16 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Set the name for each room
         self.helper.send_state(
-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+            room_id_1,
+            "m.room.name",
+            {"name": room_name_1},
+            tok=self.admin_user_tok,
         )
         self.helper.send_state(
-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+            room_id_2,
+            "m.room.name",
+            {"name": room_name_2},
+            tok=self.admin_user_tok,
         )
 
         def _search_test(
@@ -1011,7 +1047,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
             """
             url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
             channel = self.make_request(
-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+                "GET",
+                url.encode("ascii"),
+                access_token=self.admin_user_tok,
             )
             self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
 
@@ -1071,15 +1109,23 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Set the name for each room
         self.helper.send_state(
-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+            room_id_1,
+            "m.room.name",
+            {"name": room_name_1},
+            tok=self.admin_user_tok,
         )
         self.helper.send_state(
-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+            room_id_2,
+            "m.room.name",
+            {"name": room_name_2},
+            tok=self.admin_user_tok,
         )
 
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
@@ -1109,7 +1155,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["joined_local_devices"])
@@ -1121,7 +1169,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(2, channel.json_body["joined_local_devices"])
@@ -1131,7 +1181,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.helper.leave(room_id_1, user_1, tok=user_tok_1)
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["joined_local_devices"])
@@ -1160,7 +1212,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
@@ -1171,7 +1225,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
@@ -1180,6 +1236,23 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.json_body["total"], 3)
 
+    def test_room_state(self):
+        """Test that room state can be requested correctly"""
+        # Create two test rooms
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+        channel = self.make_request(
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("state", channel.json_body)
+        # testing that the state events match is painful and not done here. We assume that
+        # the create_room already does the right thing, so no need to verify that we got
+        # the state events it created.
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
@@ -1327,7 +1400,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         # Validate if user is a member of the room
 
         channel = self.make_request(
-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+            "GET",
+            "/_matrix/client/r0/joined_rooms",
+            access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
@@ -1374,7 +1449,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         # Validate if server admin is a member of the room
 
         channel = self.make_request(
-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+            "GET",
+            "/_matrix/client/r0/joined_rooms",
+            access_token=self.admin_user_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1396,7 +1473,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         # Validate if user is a member of the room
 
         channel = self.make_request(
-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+            "GET",
+            "/_matrix/client/r0/joined_rooms",
+            access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1425,11 +1504,97 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         # Validate if user is a member of the room
 
         channel = self.make_request(
-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+            "GET",
+            "/_matrix/client/r0/joined_rooms",
+            access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
+    def test_context_as_non_admin(self):
+        """
+        Test that, without being admin, one cannot use the context admin API
+        """
+        # Create a room.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+
+        self.register_user("test_2", "test")
+        user_tok_2 = self.login("test_2", "test")
+
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now attempt to find the context using the admin API without being admin.
+        midway = (len(events) - 1) // 2
+        for tok in [user_tok, user_tok_2]:
+            channel = self.make_request(
+                "GET",
+                "/_synapse/admin/v1/rooms/%s/context/%s"
+                % (room_id, events[midway]["event_id"]),
+                access_token=tok,
+            )
+            self.assertEquals(
+                403, int(channel.result["code"]), msg=channel.result["body"]
+            )
+            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_context_as_admin(self):
+        """
+        Test that, as admin, we can find the context of an event without having joined the room.
+        """
+
+        # Create a room. We're not part of it.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now let's fetch the context for this room.
+        midway = (len(events) - 1) // 2
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/rooms/%s/context/%s"
+            % (room_id, events[midway]["event_id"]),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(
+            channel.json_body["event"]["event_id"], events[midway]["event_id"]
+        )
+
+        for i, found_event in enumerate(channel.json_body["events_before"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j < midway)
+                    break
+            else:
+                self.fail("Event %s from events_before not found" % j)
+
+        for i, found_event in enumerate(channel.json_body["events_after"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j > midway)
+                    break
+            else:
+                self.fail("Event %s from events_after not found" % j)
+
 
 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
     servlets = [
@@ -1456,8 +1621,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         )
 
     def test_public_room(self):
-        """Test that getting admin in a public room works.
-        """
+        """Test that getting admin in a public room works."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
         )
@@ -1482,10 +1646,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         )
 
     def test_private_room(self):
-        """Test that getting admin in a private room works and we get invited.
-        """
+        """Test that getting admin in a private room works and we get invited."""
         room_id = self.helper.create_room_as(
-            self.creator, tok=self.creator_tok, is_public=False,
+            self.creator,
+            tok=self.creator_tok,
+            is_public=False,
         )
 
         channel = self.make_request(
@@ -1509,8 +1674,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         )
 
     def test_other_user(self):
-        """Test that giving admin in a public room works to a non-admin user works.
-        """
+        """Test that giving admin in a public room works to a non-admin user works."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
         )
@@ -1535,8 +1699,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         )
 
     def test_not_enough_power(self):
-        """Test that we get a sensible error if there are no local room admins.
-        """
+        """Test that we get a sensible error if there are no local room admins."""
         room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
         )
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index f48be3d65a..1f1d11f527 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -55,7 +55,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
         channel = self.make_request(
-            "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
+            "GET",
+            self.url,
+            json.dumps({}),
+            access_token=self.other_user_tok,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -67,7 +70,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         # unkown order_by
         channel = self.make_request(
-            "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?order_by=bar",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -75,7 +80,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # negative from
         channel = self.make_request(
-            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -83,7 +90,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # negative limit
         channel = self.make_request(
-            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -91,7 +100,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # negative from_ts
         channel = self.make_request(
-            "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from_ts=-1234",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -99,7 +110,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # negative until_ts
         channel = self.make_request(
-            "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?until_ts=-1234",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -117,7 +130,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # empty search term
         channel = self.make_request(
-            "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?search_term=",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -125,7 +140,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # invalid search order
         channel = self.make_request(
-            "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?dir=bar",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -138,7 +155,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(10, 2)
 
         channel = self.make_request(
-            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -154,7 +173,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(20, 2)
 
         channel = self.make_request(
-            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -170,7 +191,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(20, 2)
 
         channel = self.make_request(
-            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -190,7 +213,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of results is the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=20",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -201,7 +226,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=21",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -212,7 +239,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -223,7 +252,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         # Set `from` to value of `next_token` for request remaining entries
         # Check `next_token` does not appear
         channel = self.make_request(
-            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -237,7 +268,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         if users have no media created
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -264,10 +299,14 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         # order by user_id
         self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"])
         self._order_test(
-            "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f",
+            "user_id",
+            ["@user_a:test", "@user_b:test", "@user_c:test"],
+            "f",
         )
         self._order_test(
-            "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b",
+            "user_id",
+            ["@user_c:test", "@user_b:test", "@user_a:test"],
+            "b",
         )
 
         # order by displayname
@@ -275,32 +314,46 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
             "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"]
         )
         self._order_test(
-            "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f",
+            "displayname",
+            ["@user_c:test", "@user_b:test", "@user_a:test"],
+            "f",
         )
         self._order_test(
-            "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b",
+            "displayname",
+            ["@user_a:test", "@user_b:test", "@user_c:test"],
+            "b",
         )
 
         # order by media_length
         self._order_test(
-            "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"],
+            "media_length",
+            ["@user_a:test", "@user_c:test", "@user_b:test"],
         )
         self._order_test(
-            "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+            "media_length",
+            ["@user_a:test", "@user_c:test", "@user_b:test"],
+            "f",
         )
         self._order_test(
-            "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+            "media_length",
+            ["@user_b:test", "@user_c:test", "@user_a:test"],
+            "b",
         )
 
         # order by media_count
         self._order_test(
-            "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"],
+            "media_count",
+            ["@user_a:test", "@user_c:test", "@user_b:test"],
         )
         self._order_test(
-            "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f",
+            "media_count",
+            ["@user_a:test", "@user_c:test", "@user_b:test"],
+            "f",
         )
         self._order_test(
-            "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b",
+            "media_count",
+            ["@user_b:test", "@user_c:test", "@user_a:test"],
+            "b",
         )
 
     def test_from_until_ts(self):
@@ -313,14 +366,20 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         ts1 = self.clock.time_msec()
 
         # list all media when filter is not set
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
 
         # filter media starting at `ts1` after creating first media
         # result is 0
         channel = self.make_request(
-            "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from_ts=%s" % (ts1,),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 0)
@@ -342,7 +401,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # filter media until `ts2` and earlier
         channel = self.make_request(
-            "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?until_ts=%s" % (ts2,),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
@@ -351,7 +412,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(20, 1)
 
         # check without filter get all users
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
 
@@ -376,7 +441,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # filter and get empty result
         channel = self.make_request(
-            "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?search_term=foobar",
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 0)
@@ -441,7 +508,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         if dir is not None and dir in ("b", "f"):
             url += "&dir=%s" % (dir,)
         channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], len(expected_user_list))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 04599c2fcf..ba26895391 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.user1 = self.register_user(
-            "user1", "pass1", admin=False, displayname="Name 1"
-        )
-        self.user2 = self.register_user(
-            "user2", "pass2", admin=False, displayname="Name 2"
-        )
-
     def test_no_auth(self):
         """
         Try to list users without authentication.
@@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error is returned.
         """
+        self._create_users(1)
         other_user_token = self.login("user1", "pass1")
 
         channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         List all users, including deactivated users.
         """
+        self._create_users(2)
+
         channel = self.make_request(
             "GET",
             self.url + "?deactivated=true",
@@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(3, channel.json_body["total"])
 
         # Check that all fields are available
-        for u in channel.json_body["users"]:
-            self.assertIn("name", u)
-            self.assertIn("is_guest", u)
-            self.assertIn("admin", u)
-            self.assertIn("user_type", u)
-            self.assertIn("deactivated", u)
-            self.assertIn("displayname", u)
-            self.assertIn("avatar_url", u)
+        self._check_fields(channel.json_body["users"])
 
     def test_search_term(self):
         """Test that searching for a users works correctly"""
@@ -538,9 +528,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
                 search_field: Field which is to request: `name` or `user_id`
                 expected_http_code: The expected http code for the request
             """
-            url = self.url + "?%s=%s" % (search_field, search_term,)
+            url = self.url + "?%s=%s" % (
+                search_field,
+                search_term,
+            )
             channel = self.make_request(
-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+                "GET",
+                url.encode("ascii"),
+                access_token=self.admin_user_tok,
             )
             self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
 
@@ -549,6 +544,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
 
             # Check that users were returned
             self.assertTrue("users" in channel.json_body)
+            self._check_fields(channel.json_body["users"])
             users = channel.json_body["users"]
 
             # Check that the expected number of users were returned
@@ -561,25 +557,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
                 u = users[0]
                 self.assertEqual(expected_user_id, u["name"])
 
+        self._create_users(2)
+
+        user1 = "@user1:test"
+        user2 = "@user2:test"
+
         # Perform search tests
-        _search_test(self.user1, "er1")
-        _search_test(self.user1, "me 1")
+        _search_test(user1, "er1")
+        _search_test(user1, "me 1")
 
-        _search_test(self.user2, "er2")
-        _search_test(self.user2, "me 2")
+        _search_test(user2, "er2")
+        _search_test(user2, "me 2")
 
-        _search_test(self.user1, "er1", "user_id")
-        _search_test(self.user2, "er2", "user_id")
+        _search_test(user1, "er1", "user_id")
+        _search_test(user2, "er2", "user_id")
 
         # Test case insensitive
-        _search_test(self.user1, "ER1")
-        _search_test(self.user1, "NAME 1")
+        _search_test(user1, "ER1")
+        _search_test(user1, "NAME 1")
 
-        _search_test(self.user2, "ER2")
-        _search_test(self.user2, "NAME 2")
+        _search_test(user2, "ER2")
+        _search_test(user2, "NAME 2")
 
-        _search_test(self.user1, "ER1", "user_id")
-        _search_test(self.user2, "ER2", "user_id")
+        _search_test(user1, "ER1", "user_id")
+        _search_test(user2, "ER2", "user_id")
 
         _search_test(None, "foo")
         _search_test(None, "bar")
@@ -587,6 +588,205 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo", "user_id")
         _search_test(None, "bar", "user_id")
 
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid guests
+        channel = self.make_request(
+            "GET",
+            self.url + "?guests=not_bool",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid deactivated
+        channel = self.make_request(
+            "GET",
+            self.url + "?deactivated=not_bool",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of users with limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["users"])
+
+    def test_from(self):
+        """
+        Testing list of users with a defined starting point (from)
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["users"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of users with a defined starting point and limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["users"]), 10)
+        self._check_fields(channel.json_body["users"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=20",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=21",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET",
+            self.url + "?limit=19",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET",
+            self.url + "?from=19",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
+        for u in content:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("shadow_banned", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def _create_users(self, number_users: int):
+        """
+        Create a number of users
+        Args:
+            number_users: Number of users to be created
+        """
+        for i in range(1, number_users + 1):
+            self.register_user(
+                "user%d" % i,
+                "pass%d" % i,
+                admin=False,
+                displayname="Name %d" % i,
+            )
+
 
 class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
@@ -639,7 +839,10 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         self.assertEqual("You are not a server admin", channel.json_body["error"])
 
         channel = self.make_request(
-            "POST", url, access_token=self.other_user_token, content=b"{}",
+            "POST",
+            url,
+            access_token=self.other_user_token,
+            content=b"{}",
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -693,7 +896,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -717,7 +922,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -736,7 +943,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -760,7 +969,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -773,8 +984,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         self._is_erased("@user:test", False)
 
     def _is_erased(self, user_id: str, expect: bool) -> None:
-        """Assert that the user is erased or not
-        """
+        """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
         if expect:
             self.assertTrue(self.get_success(d))
@@ -808,13 +1018,20 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@bob:test"
 
-        channel = self.make_request("GET", url, access_token=self.other_user_token,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("You are not a server admin", channel.json_body["error"])
 
         channel = self.make_request(
-            "PUT", url, access_token=self.other_user_token, content=b"{}",
+            "PUT",
+            url,
+            access_token=self.other_user_token,
+            content=b"{}",
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
@@ -867,7 +1084,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -912,7 +1133,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -922,6 +1147,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(False, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual(False, channel.json_body["shadow_banned"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
@@ -1137,7 +1363,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1168,7 +1396,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1191,7 +1421,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1221,7 +1453,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1319,7 +1553,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1348,7 +1584,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_other_user, access_token=self.admin_user_tok,
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1377,7 +1615,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob", channel.json_body["displayname"])
 
         # Get user
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1397,7 +1639,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
 
         # Check user is not deactivated
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1407,8 +1653,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, channel.json_body["deactivated"])
 
     def _is_erased(self, user_id, expect):
-        """Assert that the user is erased or not
-        """
+        """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
         if expect:
             self.assertTrue(self.get_success(d))
@@ -1448,7 +1693,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("GET", self.url, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1458,7 +1707,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1470,7 +1723,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1482,7 +1739,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         if user has no memberships
         """
         # Get rooms
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1499,7 +1760,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             self.helper.create_room_as(self.other_user, tok=other_user_tok)
 
         # Get rooms
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_rooms, channel.json_body["total"])
@@ -1542,7 +1807,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
 
         # Now get rooms
         url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -1582,7 +1851,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("GET", self.url, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1592,7 +1865,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1603,7 +1880,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1614,7 +1895,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get pushers
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1641,7 +1926,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         )
 
         # Get pushers
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -1690,7 +1979,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("GET", self.url, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1700,7 +1993,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/media"
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1711,7 +2008,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
 
-        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1726,7 +2027,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self._create_media(other_user_tok, number_media)
 
         channel = self.make_request(
-            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1745,7 +2048,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self._create_media(other_user_tok, number_media)
 
         channel = self.make_request(
-            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1764,7 +2069,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self._create_media(other_user_tok, number_media)
 
         channel = self.make_request(
-            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1779,7 +2086,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1791,7 +2100,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
 
         channel = self.make_request(
-            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=-5",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1809,7 +2120,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of results is the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=20",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1820,7 +2133,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=21",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1831,7 +2146,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
         channel = self.make_request(
-            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?limit=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1843,7 +2160,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
         channel = self.make_request(
-            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+            "GET",
+            self.url + "?from=19",
+            access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1857,7 +2176,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         if user has no media created
         """
 
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1872,7 +2195,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_media, channel.json_body["total"])
@@ -1899,8 +2226,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
             )
 
     def _check_fields(self, content):
-        """Checks that all attributes are present in content
-        """
+        """Checks that all attributes are present in content"""
         for m in content:
             self.assertIn("media_id", m)
             self.assertIn("media_type", m)
@@ -1913,8 +2239,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
 
 class UserTokenRestTestCase(unittest.HomeserverTestCase):
-    """Test for /_synapse/admin/v1/users/<user>/login
-    """
+    """Test for /_synapse/admin/v1/users/<user>/login"""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -1945,16 +2270,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         return channel.json_body["access_token"]
 
     def test_no_auth(self):
-        """Try to login as a user without authentication.
-        """
+        """Try to login as a user without authentication."""
         channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_not_admin(self):
-        """Try to login as a user as a non-admin user.
-        """
+        """Try to login as a user as a non-admin user."""
         channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.other_user_tok
         )
@@ -1962,8 +2285,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
 
     def test_send_event(self):
-        """Test that sending event as a user works.
-        """
+        """Test that sending event as a user works."""
         # Create a room.
         room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
 
@@ -1977,8 +2299,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(event.sender, self.other_user)
 
     def test_devices(self):
-        """Tests that logging in as a user doesn't create a new device for them.
-        """
+        """Tests that logging in as a user doesn't create a new device for them."""
         # Login in as the user
         self._get_token()
 
@@ -1992,8 +2313,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["devices"]), 1)
 
     def test_logout(self):
-        """Test that calling `/logout` with the token works.
-        """
+        """Test that calling `/logout` with the token works."""
         # Login in as the user
         puppet_token = self._get_token()
 
@@ -2083,8 +2403,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_consent(self):
-        """Test that sending a message is not subject to the privacy policies.
-        """
+        """Test that sending a message is not subject to the privacy policies."""
         # Have the admin user accept the terms.
         self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
 
@@ -2159,11 +2478,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.register_user("user2", "pass")
         other_user2_token = self.login("user2", "pass")
 
-        channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
+        channel = self.make_request(
+            "GET",
+            self.url1,
+            access_token=other_user2_token,
+        )
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
+        channel = self.make_request(
+            "GET",
+            self.url2,
+            access_token=other_user2_token,
+        )
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
@@ -2174,11 +2501,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
         url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
 
-        channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url1,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
-        channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            url2,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
@@ -2186,12 +2521,20 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         The lookup should succeed for an admin.
         """
-        channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url1,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
+        channel = self.make_request(
+            "GET",
+            self.url2,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
@@ -2202,12 +2545,84 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("GET", self.url1, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url1,
+            access_token=other_user_token,
+        )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        channel = self.make_request("GET", self.url2, access_token=other_user_token,)
+        channel = self.make_request(
+            "GET",
+            self.url2,
+            access_token=other_user_token,
+        )
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of an user without authentication.
+        """
+        channel = self.make_request("POST", self.url)
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that shadow-banning for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+
+    def test_success(self):
+        """
+        Shadow-banning should succeed for an admin.
+        """
+        # The user starts off as not shadow-banned.
+        other_user_token = self.login("user", "pass")
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 913ea3c98e..5256c11fe6 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
 
         # Mod the mod
         room_power_levels = self.helper.get_state(
-            self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+            self.room_id,
+            "m.room.power_levels",
+            tok=self.admin_access_token,
         )
 
         # Update existing power levels with mod at PL50
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index f0707646bb..e0c74591b6 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -181,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase):
         )
 
     def test_redact_event_as_moderator_ratelimit(self):
-        """Tests that the correct ratelimiting is applied to redactions
-        """
+        """Tests that the correct ratelimiting is applied to redactions"""
 
         message_ids = []
         # as a regular user, send messages to redact
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 31dc832fd5..aee99bb6a0 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -250,7 +250,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["backfill"])
 
         self.hs = self.setup_test_homeserver(
-            config=config, federation_client=mock_federation_client,
+            config=config,
+            federation_client=mock_federation_client,
         )
         return self.hs
 
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index e689c3fbea..d2cce44032 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
 
 from tests import unittest
 
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
         self.store = self.hs.get_datastore()
 
         self.get_success(
-            self.store.db_pool.simple_update(
-                table="users",
-                keyvalues={"name": self.banned_user_id},
-                updatevalues={"shadow_banned": True},
-                desc="shadow_ban",
-            )
+            self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
         )
 
         self.other_user_id = self.register_user("otheruser", "pass")
@@ -264,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+                self.banned_user_id,
+                room_id,
+                "m.room.member",
+                self.banned_user_id,
             )
         )
         self.assertEqual(
@@ -296,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+                self.banned_user_id,
+                room_id,
+                "m.room.member",
+                self.banned_user_id,
             )
         )
         self.assertEqual(
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 0a5ca317ea..2ae896db1e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
         event_id = resp["event_id"]
 
         channel = self.make_request(
-            "GET", "/events/" + event_id, access_token=self.token,
+            "GET",
+            "/events/" + event_id,
+            access_token=self.token,
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2672ce24c6..fb29eaed6f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import create_requester
 
 from tests import unittest
@@ -75,6 +74,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
 # the query params in TEST_CLIENT_REDIRECT_URL
 EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
 
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -419,13 +422,61 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
+    def test_get_login_flows(self):
+        """GET /login should return password and SSO flows"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.sso"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_get_msc2858_login_flows(self):
+        """The SSO flow should include IdP info if MSC2858 is enabled"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # stick the flows results in a dict by type
+        flow_results = {}  # type: Dict[str, Any]
+        for f in channel.json_body["flows"]:
+            flow_type = f["type"]
+            self.assertNotIn(
+                flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+            )
+            flow_results[flow_type] = f
+
+        self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+        sso_flow = flow_results.pop("m.login.sso")
+        # we should have a set of IdPs
+        self.assertCountEqual(
+            sso_flow["org.matrix.msc2858.identity_providers"],
+            [
+                {"id": "cas", "name": "CAS"},
+                {"id": "saml", "name": "SAML"},
+                {"id": "oidc-idp1", "name": "IDP1"},
+                {"id": "oidc", "name": "OIDC"},
+            ],
+        )
+
+        # the rest of the flows are simple
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(flow_results.values(), expected_flows)
+
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
@@ -442,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
+        html = channel.result["body"].decode("utf-8")
         p = TestHtmlParser()
-        p.feed(channel.result["body"].decode("utf-8"))
+        p.feed(html)
         p.close()
 
-        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+        # there should be a link for each href
+        returned_idps = []  # type: List[str]
+        for link in p.links:
+            path, query = link.split("?", 1)
+            self.assertEqual(path, "pick_idp")
+            params = urllib.parse.parse_qs(query)
+            self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+            returned_idps.append(params["idp"][0])
 
-        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+        self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
@@ -552,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # matrix access token, mxid, and device id.
         login_token = params[2][1]
         chan = self.make_request(
-            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
         )
         self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@user1:test")
@@ -560,9 +621,47 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_multi_sso_redirect_to_unknown(self):
         """An unknown IdP should cause a 400"""
         channel = self.make_request(
-            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
         )
         self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_unknown(self):
+        """If the client tries to pick an unknown IdP, return a 404"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_oidc(self):
+        """If the client pick a known IdP, redirect to it"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
     @staticmethod
     def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -584,10 +683,12 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
 
         config = self.default_config()
+        config["public_baseurl"] = (
+            config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+        )
         config["cas_config"] = {
             "enabled": True,
             "server_url": CAS_SERVER,
-            "service_url": "https://matrix.goodserver.com:8448",
         }
 
         cas_user_id = "username"
@@ -621,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase):
         mocked_http_client.get_raw.side_effect = get_raw
 
         self.hs = self.setup_test_homeserver(
-            config=config, proxied_http_client=mocked_http_client,
+            config=config,
+            proxied_http_client=mocked_http_client,
         )
 
         return self.hs
@@ -1119,11 +1221,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_username_picker(self):
@@ -1137,7 +1236,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
         picker_url = channel.headers.getRawHeaders("Location")[0]
-        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
         cookies = {}  # type: Dict[str,str]
@@ -1149,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # looks ok.
         username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
         self.assertIn(
-            session_id, username_mapping_sessions, "session id not found in map",
+            session_id,
+            username_mapping_sessions,
+            "session id not found in map",
         )
         session = username_mapping_sessions[session_id]
         self.assertEqual(session.remote_user_id, "tester")
@@ -1161,12 +1262,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
         # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = picker_url + "/submit"
+        # to the completion page
         content = urlencode({b"username": b"bobby"}).encode("utf8")
         chan = self.make_request(
             "POST",
-            path=submit_path,
+            path=picker_url,
             content=content,
             content_is_form=True,
             custom_headers=[
@@ -1178,6 +1278,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)
         self.assertEqual(path, "https://x")
@@ -1195,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # finally, submit the matrix login token to the login API, which gives us our
         # matrix access token, mxid, and device id.
         chan = self.make_request(
-            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
         )
         self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index e59fa70baa..f3448c94dd 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,163 +14,11 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
-import json
-
-from mock import Mock
-
-from twisted.internet import defer
-
-import synapse.types
-from synapse.api.errors import AuthError, SynapseError
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, profile, room
 
 from tests import unittest
 
-from ....utils import MockHttpResource, setup_test_homeserver
-
-myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/r0"
-
-
-class MockHandlerProfileTestCase(unittest.TestCase):
-    """ Tests rest layer of profile management.
-
-    Todo: move these into ProfileTestCase
-    """
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
-        self.mock_handler = Mock(
-            spec=[
-                "get_displayname",
-                "set_displayname",
-                "get_avatar_url",
-                "set_avatar_url",
-                "check_profile_query_allowed",
-            ]
-        )
-
-        self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
-        self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
-        self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
-        self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
-        self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
-            Mock()
-        )
-
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
-            "test",
-            federation_http_client=None,
-            resource_for_client=self.mock_resource,
-            federation=Mock(),
-            federation_client=Mock(),
-            profile_handler=self.mock_handler,
-        )
-
-        async def _get_user_by_req(request=None, allow_guest=False):
-            return synapse.types.create_requester(myid)
-
-        hs.get_auth().get_user_by_req = _get_user_by_req
-
-        profile.register_servlets(hs, self.mock_resource)
-
-    @defer.inlineCallbacks
-    def test_get_my_name(self):
-        mocked_get = self.mock_handler.get_displayname
-        mocked_get.return_value = defer.succeed("Frank")
-
-        (code, response) = yield self.mock_resource.trigger(
-            "GET", "/profile/%s/displayname" % (myid), None
-        )
-
-        self.assertEquals(200, code)
-        self.assertEquals({"displayname": "Frank"}, response)
-        self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
-    @defer.inlineCallbacks
-    def test_set_my_name(self):
-        mocked_set = self.mock_handler.set_displayname
-        mocked_set.return_value = defer.succeed(())
-
-        (code, response) = yield self.mock_resource.trigger(
-            "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
-        )
-
-        self.assertEquals(200, code)
-        self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
-
-    @defer.inlineCallbacks
-    def test_set_my_name_noauth(self):
-        mocked_set = self.mock_handler.set_displayname
-        mocked_set.side_effect = AuthError(400, "message")
-
-        (code, response) = yield self.mock_resource.trigger(
-            "PUT",
-            "/profile/%s/displayname" % ("@4567:test"),
-            b'{"displayname": "Frank Jr."}',
-        )
-
-        self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
-
-    @defer.inlineCallbacks
-    def test_get_other_name(self):
-        mocked_get = self.mock_handler.get_displayname
-        mocked_get.return_value = defer.succeed("Bob")
-
-        (code, response) = yield self.mock_resource.trigger(
-            "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
-        )
-
-        self.assertEquals(200, code)
-        self.assertEquals({"displayname": "Bob"}, response)
-
-    @defer.inlineCallbacks
-    def test_set_other_name(self):
-        mocked_set = self.mock_handler.set_displayname
-        mocked_set.side_effect = SynapseError(400, "message")
-
-        (code, response) = yield self.mock_resource.trigger(
-            "PUT",
-            "/profile/%s/displayname" % ("@opaque:elsewhere"),
-            b'{"displayname":"bob"}',
-        )
-
-        self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
-
-    @defer.inlineCallbacks
-    def test_get_my_avatar(self):
-        mocked_get = self.mock_handler.get_avatar_url
-        mocked_get.return_value = defer.succeed("http://my.server/me.png")
-
-        (code, response) = yield self.mock_resource.trigger(
-            "GET", "/profile/%s/avatar_url" % (myid), None
-        )
-
-        self.assertEquals(200, code)
-        self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
-        self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
-    @defer.inlineCallbacks
-    def test_set_my_avatar(self):
-        mocked_set = self.mock_handler.set_avatar_url
-        mocked_set.return_value = defer.succeed(())
-
-        (code, response) = yield self.mock_resource.trigger(
-            "PUT",
-            "/profile/%s/avatar_url" % (myid),
-            b'{"avatar_url": "http://my.server/pic.gif"}',
-        )
-
-        self.assertEquals(200, code)
-        self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
-
 
 class ProfileTestCase(unittest.HomeserverTestCase):
 
@@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.owner = self.register_user("owner", "pass")
         self.owner_tok = self.login("owner", "pass")
+        self.other = self.register_user("other", "pass", displayname="Bob")
+
+    def test_get_displayname(self):
+        res = self._get_displayname()
+        self.assertEqual(res, "owner")
 
     def test_set_displayname(self):
         channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
-            content=json.dumps({"displayname": "test"}),
+            content={"displayname": "test"},
             access_token=self.owner_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-        res = self.get_displayname()
+        res = self._get_displayname()
         self.assertEqual(res, "test")
 
+    def test_set_displayname_noauth(self):
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/displayname" % (self.owner,),
+            content={"displayname": "test"},
+        )
+        self.assertEqual(channel.code, 401, channel.result)
+
     def test_set_displayname_too_long(self):
         """Attempts to set a stupid displayname should get a 400"""
         channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
-            content=json.dumps({"displayname": "test" * 100}),
+            content={"displayname": "test" * 100},
             access_token=self.owner_tok,
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-        res = self.get_displayname()
+        res = self._get_displayname()
         self.assertEqual(res, "owner")
 
-    def get_displayname(self):
-        channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
+    def test_get_displayname_other(self):
+        res = self._get_displayname(self.other)
+        self.assertEquals(res, "Bob")
+
+    def test_set_displayname_other(self):
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/displayname" % (self.other,),
+            content={"displayname": "test"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def test_get_avatar_url(self):
+        res = self._get_avatar_url()
+        self.assertIsNone(res)
+
+    def test_set_avatar_url(self):
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/avatar_url" % (self.owner,),
+            content={"avatar_url": "http://my.server/pic.gif"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        res = self._get_avatar_url()
+        self.assertEqual(res, "http://my.server/pic.gif")
+
+    def test_set_avatar_url_noauth(self):
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/avatar_url" % (self.owner,),
+            content={"avatar_url": "http://my.server/pic.gif"},
+        )
+        self.assertEqual(channel.code, 401, channel.result)
+
+    def test_set_avatar_url_too_long(self):
+        """Attempts to set a stupid avatar_url should get a 400"""
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/avatar_url" % (self.owner,),
+            content={"avatar_url": "http://my.server/pic.gif" * 100},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+        res = self._get_avatar_url()
+        self.assertIsNone(res)
+
+    def test_get_avatar_url_other(self):
+        res = self._get_avatar_url(self.other)
+        self.assertIsNone(res)
+
+    def test_set_avatar_url_other(self):
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/avatar_url" % (self.other,),
+            content={"avatar_url": "http://my.server/pic.gif"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def _get_displayname(self, name=None):
+        channel = self.make_request(
+            "GET", "/profile/%s/displayname" % (name or self.owner,)
+        )
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["displayname"]
 
+    def _get_avatar_url(self, name=None):
+        channel = self.make_request(
+            "GET", "/profile/%s/avatar_url" % (name or self.owner,)
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        return channel.json_body.get("avatar_url")
+
 
 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..ed65f645fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         self.hs = self.setup_test_homeserver(
-            "red", federation_http_client=None, federation_client=Mock(),
+            "red",
+            federation_http_client=None,
+            federation_client=Mock(),
         )
 
         self.hs.get_federation_handler = Mock()
@@ -616,6 +618,41 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomInviteRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        admin.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_rooms_ratelimit(self):
+        """Tests that invites in a room are actually rate-limited."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        for i in range(3):
+            self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+        self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_users_ratelimit(self):
+        """Tests that invites to a specific user are actually rate-limited."""
+
+        for i in range(3):
+            room_id = self.helper.create_room_as(self.user_id)
+            self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
 
@@ -1445,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         results = channel.json_body["search_categories"]["room_events"]["results"]
 
         self.assertEqual(
-            len(results), 2, [result["result"]["content"] for result in results],
+            len(results),
+            2,
+            [result["result"]["content"] for result in results],
         )
         self.assertEqual(
             results[0]["result"]["content"]["body"],
@@ -1480,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         results = channel.json_body["search_categories"]["room_events"]["results"]
 
         self.assertEqual(
-            len(results), 4, [result["result"]["content"] for result in results],
+            len(results),
+            4,
+            [result["result"]["content"] for result in results],
         )
         self.assertEqual(
             results[0]["result"]["content"]["body"],
@@ -1527,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         results = channel.json_body["search_categories"]["room_events"]["results"]
 
         self.assertEqual(
-            len(results), 1, [result["result"]["content"] for result in results],
+            len(results),
+            1,
+            [result["result"]["content"] for result in results],
         )
         self.assertEqual(
             results[0]["result"]["content"]["body"],
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..329dbd06de 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.rest.client.v1 import room
 from synapse.types import UserID
 
@@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            "red", federation_http_client=None, federation_client=Mock(),
+            "red",
+            federation_http_client=None,
+            federation_client=Mock(),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
@@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
-        def get_room_members(room_id):
-            if room_id == self.room_id:
-                return defer.succeed([self.user])
-            else:
-                return defer.succeed([])
-
-        @defer.inlineCallbacks
-        def fetch_room_distributions_into(
-            room_id, localusers=None, remotedomains=None, ignore_user=None
-        ):
-            members = yield get_room_members(room_id)
-            for member in members:
-                if ignore_user is not None and member == ignore_user:
-                    continue
-
-                if hs.is_mine(member):
-                    if localusers is not None:
-                        localusers.add(member)
-                else:
-                    if remotedomains is not None:
-                        remotedomains.add(member.domain)
-
-        hs.get_room_member_handler().fetch_room_distributions_into = (
-            fetch_room_distributions_into
-        )
-
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index b1333df82d..8231a423f3 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -166,9 +166,12 @@ class RestHelper:
             json.dumps(data).encode("utf8"),
         )
 
-        assert int(channel.result["code"]) == expect_code, (
-            "Expected: %d, got: %d, resp: %r"
-            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
         )
 
         self.auth_user_id = temp_id
@@ -201,9 +204,12 @@ class RestHelper:
             json.dumps(content).encode("utf8"),
         )
 
-        assert int(channel.result["code"]) == expect_code, (
-            "Expected: %d, got: %d, resp: %r"
-            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
         )
 
         return channel.json_body
@@ -251,9 +257,12 @@ class RestHelper:
 
         channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
 
-        assert int(channel.result["code"]) == expect_code, (
-            "Expected: %d, got: %d, resp: %r"
-            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
         )
 
         return channel.json_body
@@ -447,7 +456,10 @@ class RestHelper:
         return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
 
     def complete_oidc_auth(
-        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+        self,
+        oauth_uri: str,
+        cookies: Mapping[str, str],
+        user_info_dict: JsonDict,
     ) -> FakeChannel:
         """Mock out an OIDC authentication flow
 
@@ -491,7 +503,9 @@ class RestHelper:
             (expected_uri, resp_obj) = expected_requests.pop(0)
             assert uri == expected_uri
             resp = FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+                code=200,
+                phrase=b"OK",
+                body=json.dumps(resp_obj).encode("utf-8"),
             )
             return resp
 
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..e72b61963d 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -75,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
     def test_basic_password_reset(self):
-        """Test basic password reset flow
-        """
+        """Test basic password reset flow"""
         old_password = "monkey"
         new_password = "kangeroo"
 
@@ -112,6 +111,55 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_email(self):
+        """Test that we ratelimit /requestToken for the same email."""
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test1@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        def reset(ip):
+            client_secret = "foobar"
+            session_id = self._request_token(email, client_secret, ip)
+
+            self.assertEquals(len(self.email_attempts), 1)
+            link = self._get_link_from_email()
+
+            self._validate_token(link)
+
+            self._reset_password(new_password, session_id, client_secret)
+
+            self.email_attempts.clear()
+
+        # We expect to be able to make three requests before getting rate
+        # limited.
+        #
+        # We change IPs to ensure that we're not being ratelimited due to the
+        # same IP
+        reset("127.0.0.1")
+        reset("127.0.0.2")
+        reset("127.0.0.3")
+
+        with self.assertRaises(HttpResponseException) as cm:
+            reset("127.0.0.4")
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_basic_password_reset_canonicalise_email(self):
         """Test basic password reset flow
         Request password reset with different spelling
@@ -153,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.attempt_wrong_password_login("kermit", old_password)
 
     def test_cant_reset_password_without_clicking_link(self):
-        """Test that we do actually need to click the link in the email
-        """
+        """Test that we do actually need to click the link in the email"""
         old_password = "monkey"
         new_password = "kangeroo"
 
@@ -239,13 +286,20 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret):
+    def _request_token(self, email, client_secret, ip="127.0.0.1"):
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            client_ip=ip,
         )
-        self.assertEquals(200, channel.code, channel.result)
+
+        if channel.code != 200:
+            raise HttpResponseException(
+                channel.code,
+                channel.result["reason"],
+                channel.result["body"],
+            )
 
         return channel.json_body["sid"]
 
@@ -509,9 +563,22 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_address_trim(self):
         self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_ip(self):
+        """Tests that adding emails is ratelimited by IP"""
+
+        # We expect to be able to set three emails before getting ratelimited.
+        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+        with self.assertRaises(HttpResponseException) as cm:
+            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_add_email_if_disabled(self):
-        """Test adding email to profile when doing so is disallowed
-        """
+        """Test adding email to profile when doing so is disallowed"""
         self.hs.config.enable_3pid_changes = False
 
         client_secret = "foobar"
@@ -541,15 +608,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
     def test_delete_email(self):
-        """Test deleting an email from profile
-        """
+        """Test deleting an email from profile"""
         # Add a threepid
         self.get_success(
             self.store.user_add_threepid(
@@ -571,15 +639,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
     def test_delete_email_if_disabled(self):
-        """Test deleting an email from profile when disallowed
-        """
+        """Test deleting an email from profile when disallowed"""
         self.hs.config.enable_3pid_changes = False
 
         # Add a threepid
@@ -605,7 +674,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -613,8 +684,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
 
     def test_cant_add_email_without_clicking_link(self):
-        """Test that we do actually need to click the link in the email
-        """
+        """Test that we do actually need to click the link in the email"""
         client_secret = "foobar"
         session_id = self._request_token(self.email, client_secret)
 
@@ -640,7 +710,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -673,7 +745,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -718,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure not providing a next_link parameter still works
         self._request_token(
-            "something@example.com", "some_secret", next_link=None, expect_code=200,
+            "something@example.com",
+            "some_secret",
+            next_link=None,
+            expect_code=200,
         )
 
         self._request_token(
@@ -776,13 +853,27 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         if next_link:
             body["next_link"] = next_link
 
-        channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
-        self.assertEquals(expect_code, channel.code, channel.result)
+        channel = self.make_request(
+            "POST",
+            b"account/3pid/email/requestToken",
+            body,
+        )
+
+        if channel.code != expect_code:
+            raise HttpResponseException(
+                channel.code,
+                channel.result["reason"],
+                channel.result["body"],
+            )
 
         return channel.json_body.get("sid")
 
     def _request_token_invalid_email(
-        self, email, expected_errcode, expected_error, client_secret="foobar",
+        self,
+        email,
+        expected_errcode,
+        expected_error,
+        client_secret="foobar",
     ):
         channel = self.make_request(
             "POST",
@@ -821,12 +912,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         return match.group(0)
 
     def _add_email(self, request_email, expected_email):
-        """Test adding an email to profile
-        """
+        """Test adding an email to profile"""
+        previous_email_attempts = len(self.email_attempts)
+
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -850,9 +942,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         # Get user
         channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+            "GET",
+            self.url_3pid,
+            access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+        self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index a6488a3d29..c26ad824f7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -22,7 +22,7 @@ from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.rest.oidc import OIDCResource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -102,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
         channel = self.register(
-            401, {"username": "user", "type": "m.login.password", "password": "bar"},
+            401,
+            {"username": "user", "type": "m.login.password", "password": "bar"},
         )
 
         # Grab the session
@@ -173,9 +174,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def create_resource_dict(self):
         resource_dict = super().create_resource_dict()
-        if HAS_OIDC:
-            # mount the OIDC resource at /_synapse/oidc
-            resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
         return resource_dict
 
     def prepare(self, reactor, clock, hs):
@@ -193,7 +192,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
     ) -> FakeChannel:
         """Delete an individual device."""
         channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=access_token,
+            "DELETE",
+            "devices/" + device,
+            body,
+            access_token=access_token,
         )
 
         # Ensure the response is sane.
@@ -206,7 +208,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Note that this uses the delete_devices endpoint so that we can modify
         # the payload half-way through some tests.
         channel = self.make_request(
-            "POST", "delete_devices", body, access_token=self.user_tok,
+            "POST",
+            "delete_devices",
+            body,
+            access_token=self.user_tok,
         )
 
         # Ensure the response is sane.
@@ -338,7 +343,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+    @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
     def test_can_reuse_session(self):
         """
         The session can be reused if configured.
@@ -419,7 +424,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # and now the delete request should succeed.
         self.delete_device(
-            self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+            self.user_tok,
+            self.device_id,
+            200,
+            body={"auth": {"session": session_id}},
         )
 
     @skip_unless(HAS_OIDC, "requires OIDC")
@@ -445,8 +453,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self):
-        """A user that had a password and then logged in with SSO should get both flows
-        """
+        """A user that had a password and then logged in with SSO should get both flows"""
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
 
@@ -461,8 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_ui_auth_fails_for_incorrect_sso_user(self):
-        """If the user tries to authenticate with the wrong SSO user, they get an error
-        """
+        """If the user tries to authenticate with the wrong SSO user, they get an error"""
         # log the user in
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index fba34def30..5ebc5707a5 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -91,7 +91,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
-            channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+            channel.json_body["errcode"],
+            Codes.PASSWORD_TOO_SHORT,
+            channel.result,
         )
 
     def test_password_no_digit(self):
@@ -100,7 +102,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
-            channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+            channel.json_body["errcode"],
+            Codes.PASSWORD_NO_DIGIT,
+            channel.result,
         )
 
     def test_password_no_symbol(self):
@@ -109,7 +113,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
-            channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+            channel.json_body["errcode"],
+            Codes.PASSWORD_NO_SYMBOL,
+            channel.result,
         )
 
     def test_password_no_uppercase(self):
@@ -118,7 +124,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
-            channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+            channel.json_body["errcode"],
+            Codes.PASSWORD_NO_UPPERCASE,
+            channel.result,
         )
 
     def test_password_no_lowercase(self):
@@ -127,7 +135,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
-            channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+            channel.json_body["errcode"],
+            Codes.PASSWORD_NO_LOWERCASE,
+            channel.result,
         )
 
     def test_password_compliant(self):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index bd574077e7..7c457754f1 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -83,14 +83,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     def test_deny_membership(self):
-        """Test that we deny relations on membership events
-        """
+        """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
     def test_deny_double_react(self):
-        """Test that we deny relations on membership events
-        """
+        """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
@@ -98,8 +96,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(400, channel.code, channel.json_body)
 
     def test_basic_paginate_relations(self):
-        """Tests that calling pagination API correctly the latest relations.
-        """
+        """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
         self.assertEquals(200, channel.code, channel.json_body)
 
@@ -174,8 +171,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(found_event_ids, expected_event_ids)
 
     def test_aggregation_pagination_groups(self):
-        """Test that we can paginate annotation groups correctly.
-        """
+        """Test that we can paginate annotation groups correctly."""
 
         # We need to create ten separate users to send each reaction.
         access_tokens = [self.user_token, self.user2_token]
@@ -240,8 +236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(sent_groups, found_groups)
 
     def test_aggregation_pagination_within_group(self):
-        """Test that we can paginate within an annotation group.
-        """
+        """Test that we can paginate within an annotation group."""
 
         # We need to create ten separate users to send each reaction.
         access_tokens = [self.user_token, self.user2_token]
@@ -311,8 +306,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(found_event_ids, expected_event_ids)
 
     def test_aggregation(self):
-        """Test that annotations get correctly aggregated.
-        """
+        """Test that annotations get correctly aggregated."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(200, channel.code, channel.json_body)
@@ -344,8 +338,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     def test_aggregation_redactions(self):
-        """Test that annotations get correctly aggregated after a redaction.
-        """
+        """Test that annotations get correctly aggregated after a redaction."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(200, channel.code, channel.json_body)
@@ -379,8 +372,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     def test_aggregation_must_be_annotation(self):
-        """Test that aggregations must be annotations.
-        """
+        """Test that aggregations must be annotations."""
 
         channel = self.make_request(
             "GET",
@@ -437,8 +429,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     def test_edit(self):
-        """Test that a simple edit works.
-        """
+        """Test that a simple edit works."""
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         channel = self._send_relation(
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 512e36c236..2dbf42397a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -388,13 +388,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         # Check that room name changes increase the unread counter.
         self.helper.send_state(
-            self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+            self.room_id,
+            "m.room.name",
+            {"name": "my super room"},
+            tok=self.tok2,
         )
         self._check_unread_count(1)
 
         # Check that room topic changes increase the unread counter.
         self.helper.send_state(
-            self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+            self.room_id,
+            "m.room.topic",
+            {"topic": "welcome!!!"},
+            tok=self.tok2,
         )
         self._check_unread_count(2)
 
@@ -404,7 +410,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         # Check that custom events with a body increase the unread counter.
         self.helper.send_event(
-            self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+            self.room_id,
+            "org.matrix.custom_type",
+            {"body": "hello"},
+            tok=self.tok2,
         )
         self._check_unread_count(4)
 
@@ -443,14 +452,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         """Syncs and compares the unread count with the expected value."""
 
         channel = self.make_request(
-            "GET", self.url % self.next_batch, access_token=self.tok,
+            "GET",
+            self.url % self.next_batch,
+            access_token=self.tok,
         )
 
         self.assertEqual(channel.code, 200, channel.json_body)
 
         room_entry = channel.json_body["rooms"]["join"][self.room_id]
         self.assertEqual(
-            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+            room_entry["org.matrix.msc2654.unread_count"],
+            expected_count,
+            room_entry,
         )
 
         # Store the next batch for the next request.
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py
new file mode 100644
index 0000000000..d890d11863
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_upgrade_room.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+from tests.server import FakeChannel
+
+
+class UpgradeRoomTest(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        room_upgrade_rest_servlet.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.handler = hs.get_user_directory_handler()
+
+        self.creator = self.register_user("creator", "pass")
+        self.creator_token = self.login(self.creator, "pass")
+
+        self.other = self.register_user("user", "pass")
+        self.other_token = self.login(self.other, "pass")
+
+        self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
+        self.helper.join(self.room_id, self.other, tok=self.other_token)
+
+    def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
+        # We never want a cached response.
+        self.reactor.advance(5 * 60 + 1)
+
+        return self.make_request(
+            "POST",
+            "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
+            # This will upgrade a room to the same version, but that's fine.
+            content={"new_version": DEFAULT_ROOM_VERSION},
+            access_token=token or self.creator_token,
+        )
+
+    def test_upgrade(self):
+        """
+        Upgrading a room should work fine.
+        """
+        channel = self._upgrade_room()
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertIn("replacement_room", channel.json_body)
+
+    def test_not_in_room(self):
+        """
+        Upgrading a room should work fine.
+        """
+        # THe user isn't in the room.
+        roomless = self.register_user("roomless", "pass")
+        roomless_token = self.login(roomless, "pass")
+
+        channel = self._upgrade_room(roomless_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+    def test_power_levels(self):
+        """
+        Another user can upgrade the room if their power level is increased.
+        """
+        # The other user doesn't have the proper power level.
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+        # Increase the power levels so that this user can upgrade.
+        power_levels = self.helper.get_state(
+            self.room_id,
+            "m.room.power_levels",
+            tok=self.creator_token,
+        )
+        power_levels["users"][self.other] = 100
+        self.helper.send_state(
+            self.room_id,
+            "m.room.power_levels",
+            body=power_levels,
+            tok=self.creator_token,
+        )
+
+        # The upgrade should succeed!
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(200, channel.code, channel.result)
+
+    def test_power_levels_user_default(self):
+        """
+        Another user can upgrade the room if the default power level for users is increased.
+        """
+        # The other user doesn't have the proper power level.
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+        # Increase the power levels so that this user can upgrade.
+        power_levels = self.helper.get_state(
+            self.room_id,
+            "m.room.power_levels",
+            tok=self.creator_token,
+        )
+        power_levels["users_default"] = 100
+        self.helper.send_state(
+            self.room_id,
+            "m.room.power_levels",
+            body=power_levels,
+            tok=self.creator_token,
+        )
+
+        # The upgrade should succeed!
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(200, channel.code, channel.result)
+
+    def test_power_levels_tombstone(self):
+        """
+        Another user can upgrade the room if they can send the tombstone event.
+        """
+        # The other user doesn't have the proper power level.
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+        # Increase the power levels so that this user can upgrade.
+        power_levels = self.helper.get_state(
+            self.room_id,
+            "m.room.power_levels",
+            tok=self.creator_token,
+        )
+        power_levels["events"]["m.room.tombstone"] = 0
+        self.helper.send_state(
+            self.room_id,
+            "m.room.power_levels",
+            body=power_levels,
+            tok=self.creator_token,
+        )
+
+        # The upgrade should succeed!
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(200, channel.code, channel.result)
+
+        power_levels = self.helper.get_state(
+            self.room_id,
+            "m.room.power_levels",
+            tok=self.creator_token,
+        )
+        self.assertNotIn(self.other, power_levels["users"])
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 5e90d656f7..9d0d0ef414 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -180,7 +180,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         async def post_json(destination, path, data):
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
-                path, "/_matrix/key/v2/query",
+                path,
+                "/_matrix/key/v2/query",
             )
 
             channel = FakeChannel(self.site, self.reactor)
@@ -188,7 +189,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             req.content = BytesIO(encode_canonical_json(data))
 
             req.requestReceived(
-                b"POST", path.encode("utf-8"), b"1.1",
+                b"POST",
+                path.encode("utf-8"),
+                b"1.1",
             )
             channel.await_result()
             self.assertEqual(channel.code, 200)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index ae2b32b131..0789b12392 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -164,7 +167,16 @@ class _TestImage:
             ),
         ),
         # an empty file
-        (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
+        (
+            _TestImage(
+                b"",
+                b"image/gif",
+                b".gif",
+                None,
+                None,
+                False,
+            ),
+        ),
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
@@ -202,7 +214,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         config = self.default_config()
         config["media_store_path"] = self.media_store_path
-        config["thumbnail_requirements"] = {}
         config["max_image_pixels"] = 2000000
 
         provider_config = {
@@ -313,15 +324,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
     def test_thumbnail_crop(self):
+        """Test that a cropped remote thumbnail is available."""
         self._test_thumbnail(
             "crop", self.test_image.expected_cropped, self.test_image.expected_found
         )
 
     def test_thumbnail_scale(self):
+        """Test that a scaled remote thumbnail is available."""
         self._test_thumbnail(
             "scale", self.test_image.expected_scaled, self.test_image.expected_found
         )
 
+    def test_invalid_type(self):
+        """An invalid thumbnail type is never available."""
+        self._test_thumbnail("invalid", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+    )
+    def test_no_thumbnail_crop(self):
+        """
+        Override the config to generate only scaled thumbnails, but request a cropped one.
+        """
+        self._test_thumbnail("crop", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+    )
+    def test_no_thumbnail_scale(self):
+        """
+        Override the config to generate only cropped thumbnails, but request a scaled one.
+        """
+        self._test_thumbnail("scale", None, False)
+
     def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
         channel = make_request(
@@ -375,3 +410,93 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"X-Robots-Tag"),
             [b"noindex, nofollow, noarchive, noimageindex"],
         )
+
+
+class TestSpamChecker:
+    """A spam checker module that rejects all media that includes the bytes
+    `evil`.
+    """
+
+    def __init__(self, config, api):
+        self.config = config
+        self.api = api
+
+    def parse_config(config):
+        return config
+
+    async def check_event_for_spam(self, foo):
+        return False  # allow all events
+
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        return True  # allow all invites
+
+    async def user_may_create_room(self, userid):
+        return True  # allow all room creations
+
+    async def user_may_create_room_alias(self, userid, room_alias):
+        return True  # allow all room aliases
+
+    async def user_may_publish_room(self, userid, room_id):
+        return True  # allow publishing of all rooms
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+    def default_config(self):
+        config = default_config("test")
+
+        config.update(
+            {
+                "spam_checker": [
+                    {
+                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "config": {},
+                    }
+                ]
+            }
+        )
+
+        return config
+
+    def test_upload_innocent(self):
+        """Attempt to upload some innocent data that should be allowed."""
+
+        image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+        self.helper.upload_media(
+            self.upload_resource, image_data, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self):
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        data = b"Some evil data"
+
+        self.helper.upload_media(
+            self.upload_resource, data, tok=self.tok, expect_code=400
+        )
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index c5e44af9f7..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -40,3 +40,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
                 "m.identity_server": {"base_url": "https://testis"},
             },
         )
+
+    def test_well_known_no_public_baseurl(self):
+        self.hs.config.public_baseurl = None
+
+        channel = self.make_request(
+            "GET", "/.well-known/matrix/client", shorthand=False
+        )
+
+        self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index 5a85d5fe7f..d4ece5c448 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,6 +47,7 @@ class FakeChannel:
     site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
+    _ip = attr.ib(type=str, default="127.0.0.1")
     _producer = None
 
     @property
@@ -120,7 +121,7 @@ class FakeChannel:
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address("TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
         return None
@@ -196,6 +197,7 @@ def make_request(
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
     ] = None,
+    client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
     Make a web request using the given method, path and content, and render it
@@ -223,6 +225,9 @@ def make_request(
              will pump the reactor until the the renderer tells the channel the request
              is finished.
 
+        client_ip: The IP to use as the requesting IP. Useful for testing
+            ratelimiting.
+
     Returns:
         channel
     """
@@ -250,7 +255,7 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    channel = FakeChannel(site, reactor)
+    channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel)
     req.content = BytesIO(content)
@@ -342,8 +347,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self._tcp_callbacks[(host, port)] = callback
 
     def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
-        """Fake L{IReactorTCP.connectTCP}.
-        """
+        """Fake L{IReactorTCP.connectTCP}."""
 
         conn = super().connectTCP(
             host, port, factory, timeout=timeout, bindAddress=None
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index fea54464af..d40d65b06a 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -353,7 +353,11 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
             tok = self.login(localpart, "password")
 
             # Sync with the user's token to mark the user as active.
-            channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
+            channel = self.make_request(
+                "GET",
+                "/sync?timeout=0",
+                access_token=tok,
+            )
 
             # Also retrieves the list of invites for this user. We don't care about that
             # one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 77c72834f2..66e3cafe8e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -382,8 +382,7 @@ class StateTestCase(unittest.TestCase):
         self.do_check(events, edges, expected_state_ids)
 
     def test_mainline_sort(self):
-        """Tests that the mainline ordering works correctly.
-        """
+        """Tests that the mainline ordering works correctly."""
 
         events = [
             FakeEvent(
@@ -660,15 +659,27 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         #           C -|-> B -> A
 
         a = FakeEvent(
-            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="A",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([], [])
 
         b = FakeEvent(
-            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="B",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([a.event_id], [])
 
         c = FakeEvent(
-            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="C",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([b.event_id], [])
 
         persisted_events = {a.event_id: a, b.event_id: b}
@@ -694,19 +705,35 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         #      D -> C -|-> B -> A
 
         a = FakeEvent(
-            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="A",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([], [])
 
         b = FakeEvent(
-            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="B",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([a.event_id], [])
 
         c = FakeEvent(
-            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="C",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([b.event_id], [])
 
         d = FakeEvent(
-            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="D",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([c.event_id], [])
 
         persisted_events = {a.event_id: a, b.event_id: b}
@@ -737,23 +764,43 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         #              |
 
         a = FakeEvent(
-            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="A",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([], [])
 
         b = FakeEvent(
-            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="B",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([a.event_id], [])
 
         c = FakeEvent(
-            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="C",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([b.event_id], [])
 
         d = FakeEvent(
-            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="D",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([c.event_id], [])
 
         e = FakeEvent(
-            id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+            id="E",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key="",
+            content={},
         ).to_event([c.event_id, b.event_id], [])
 
         persisted_events = {a.event_id: a, b.event_id: b}
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 673e1fe3e3..38444e48e2 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -96,7 +96,9 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         # No ignored_users key.
         self.get_success(
             self.store.add_account_data_for_user(
-                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {},
             )
         )
 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 02aae1c13d..1b4fae0bb5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -67,7 +67,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         async def update(progress, count):
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
-                count, target_background_update_duration_ms / duration_ms, places=0,
+                count,
+                target_background_update_duration_ms / duration_ms,
+                places=0,
             )
             await self.updates._end_background_update("test_update")
             return count
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index c13a57dad1..7791138688 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         self.room_id = info["room_id"]
 
     def run_background_update(self):
-        """Re run the background update to clean up the extremities.
-        """
+        """Re run the background update to clean up the extremities."""
         # Make sure we don't clash with in progress updates.
         self.assertTrue(
             self.store.db_pool.updates._all_done, "Background updates are still ongoing"
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a69117c5a9..34e6526097 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -41,7 +41,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
@@ -214,7 +220,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
@@ -303,7 +315,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 0c46ad595b..16daa66cc9 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -90,7 +90,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
                     "content": {"tag": "power"},
                 },
             ).build(
-                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id],
             )
         )
 
@@ -226,7 +227,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
 
             self.assertFalse(
                 link_map.exists_path_from(
-                    chain_map[create.event_id], chain_map[event.event_id],
+                    chain_map[create.event_id],
+                    chain_map[event.event_id],
                 ),
             )
 
@@ -287,7 +289,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
                     "content": {"tag": "power"},
                 },
             ).build(
-                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id],
             )
         )
 
@@ -373,7 +376,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
 
     def persist(
-        self, events: List[EventBase],
+        self,
+        events: List[EventBase],
     ):
         """Persist the given events and check that the links generated match
         those given.
@@ -394,7 +398,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             persist_events_store._persist_event_auth_chain_txn(txn, events)
 
         self.get_success(
-            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+            persist_events_store.db_pool.runInteraction(
+                "_persist",
+                _persist,
+            )
         )
 
     def fetch_chains(
@@ -447,8 +454,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
 
 class LinkMapTestCase(unittest.TestCase):
     def test_simple(self):
-        """Basic tests for the LinkMap.
-        """
+        """Basic tests for the LinkMap."""
         link_map = _LinkMap()
 
         link_map.add_link((1, 1), (2, 1), new=False)
@@ -490,8 +496,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         self.requester = create_requester(self.user_id)
 
     def _generate_room(self) -> Tuple[str, List[Set[str]]]:
-        """Insert a room without a chain cover index.
-        """
+        """Insert a room without a chain cover index."""
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
         # Mark the room as not having a chain cover index
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9d04a066d8..06000f81a6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -215,7 +215,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 ],
             )
 
-        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "insert",
+                insert_event,
+            )
+        )
 
         # Now actually test that various combinations give the right result:
 
@@ -370,7 +375,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
 
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
-                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+                txn,
+                [FakeEvent("b", room_id, auth_graph["b"])],
             )
 
             self.store.db_pool.simple_update_txn(
@@ -380,7 +386,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": True},
             )
 
-        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "insert",
+                insert_event,
+            )
+        )
 
         # Now actually test that various combinations give the right result:
 
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index c0595963dd..485f1ee033 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -84,7 +84,9 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
             yield defer.ensureDeferred(
                 self.store.add_push_actions_to_staging(
-                    event.event_id, {user_id: action}, False,
+                    event.event_id,
+                    {user_id: action},
+                    False,
                 )
             )
             yield defer.ensureDeferred(
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 71210ce606..ed898b8dbb 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -68,16 +68,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
         self.assert_extremities([self.remote_event_1.event_id])
 
     def persist_event(self, event, state=None):
-        """Persist the event, with optional state
-        """
+        """Persist the event, with optional state"""
         context = self.get_success(
             self.state.compute_event_context(event, old_state=state)
         )
         self.get_success(self.persistence.persist_event(event, context))
 
     def assert_extremities(self, expected_extremities):
-        """Assert the current extremities for the room
-        """
+        """Assert the current extremities for the room"""
         extremities = self.get_success(
             self.store.get_prev_events_for_room(self.room_id)
         )
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 3e2fd4da01..aad6bc907e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -86,7 +86,11 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         def _insert(txn):
             txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+                "INSERT INTO foobar VALUES (?, ?)",
+                (
+                    stream_id,
+                    instance_name,
+                ),
             )
             txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
             txn.execute(
@@ -138,8 +142,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_out_of_order_finish(self):
-        """Test that IDs persisted out of order are correctly handled
-        """
+        """Test that IDs persisted out of order are correctly handled"""
 
         # Prefill table with 7 rows written by 'master'
         self._insert_rows("master", 7)
@@ -246,8 +249,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
     def test_get_next_txn(self):
-        """Test that the `get_next_txn` function works correctly.
-        """
+        """Test that the `get_next_txn` function works correctly."""
 
         # Prefill table with 7 rows written by 'master'
         self._insert_rows("master", 7)
@@ -386,8 +388,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
 
     def test_writer_config_change(self):
-        """Test that changing the writer config correctly works.
-        """
+        """Test that changing the writer config correctly works."""
 
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
@@ -434,8 +435,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
 
     def test_sequence_consistency(self):
-        """Test that we error out if the table and sequence diverges.
-        """
+        """Test that we error out if the table and sequence diverges."""
 
         # Prefill with some rows
         self._insert_row_with_id("master", 3)
@@ -452,8 +452,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
-    """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
-    """
+    """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
 
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
@@ -494,12 +493,15 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         return self.get_success(self.db_pool.runWithConnection(_create))
 
     def _insert_row(self, instance_name: str, stream_id: int):
-        """Insert one row as the given instance with given stream_id.
-        """
+        """Insert one row as the given instance with given stream_id."""
 
         def _insert(txn):
             txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+                "INSERT INTO foobar VALUES (?, ?)",
+                (
+                    stream_id,
+                    instance_name,
+                ),
             )
             txn.execute(
                 """
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 8d97b6d4cd..5858c7fcc4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -198,7 +198,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     # value, although it gets stored on the config object as mau_limits.
     @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
     def test_reap_monthly_active_users_reserved_users(self):
-        """ Tests that reaping correctly handles reaping where reserved users are
+        """Tests that reaping correctly handles reaping where reserved users are
         present"""
         threepids = self.hs.config.mau_limits_reserved_threepids
         initial_users = len(threepids)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index a6303bf0ee..b2a0e60856 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -299,8 +299,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_redact_censor(self):
-        """Test that a redacted event gets censored in the DB after a month
-        """
+        """Test that a redacted event gets censored in the DB after a month"""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -370,8 +369,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.assert_dict({"content": {}}, json.loads(event_json))
 
     def test_redact_redaction(self):
-        """Tests that we can redact a redaction and can fetch it again.
-        """
+        """Tests that we can redact a redaction and can fetch it again."""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -404,8 +402,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_store_redacted_redaction(self):
-        """Tests that we can store a redacted redaction.
-        """
+        """Tests that we can store a redacted redaction."""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c8c7a90e5d..4eb41c46e8 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -52,6 +52,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "creation_ts": 1000,
                 "user_type": None,
                 "deactivated": 0,
+                "shadow_banned": 0,
             },
             (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
         )
@@ -145,7 +146,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         try:
             yield defer.ensureDeferred(
                 self.store.validate_threepid_session(
-                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                    "fake_sid",
+                    "fake_client_secret",
+                    "fake_token",
+                    0,
                 )
             )
         except ThreepidValidationError as e:
@@ -158,7 +162,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         try:
             yield defer.ensureDeferred(
                 self.store.validate_threepid_session(
-                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                    "fake_sid",
+                    "fake_client_secret",
+                    "fake_token",
+                    0,
                 )
             )
         except ThreepidValidationError as e:
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 69b4c5d6c2..3f2691ee6b 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -85,7 +85,10 @@ class EventAuthTestCase(unittest.TestCase):
 
         # king should be able to send state
         event_auth.check(
-            RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
+            RoomVersions.V1,
+            _random_state_event(king),
+            auth_events,
+            do_sig_check=False,
         )
 
     def test_alias_event(self):
@@ -99,7 +102,10 @@ class EventAuthTestCase(unittest.TestCase):
 
         # creator should be able to send aliases
         event_auth.check(
-            RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
+            RoomVersions.V1,
+            _alias_event(creator),
+            auth_events,
+            do_sig_check=False,
         )
 
         # Reject an event with no state key.
@@ -122,7 +128,10 @@ class EventAuthTestCase(unittest.TestCase):
 
         # Note that the member does *not* need to be in the room.
         event_auth.check(
-            RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
+            RoomVersions.V1,
+            _alias_event(other),
+            auth_events,
+            do_sig_check=False,
         )
 
     def test_msc2432_alias_event(self):
@@ -136,7 +145,10 @@ class EventAuthTestCase(unittest.TestCase):
 
         # creator should be able to send aliases
         event_auth.check(
-            RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
+            RoomVersions.V6,
+            _alias_event(creator),
+            auth_events,
+            do_sig_check=False,
         )
 
         # No particular checks are done on the state key.
@@ -156,7 +168,10 @@ class EventAuthTestCase(unittest.TestCase):
         # Per standard auth rules, the member must be in the room.
         with self.assertRaises(AuthError):
             event_auth.check(
-                RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+                RoomVersions.V6,
+                _alias_event(other),
+                auth_events,
+                do_sig_check=False,
             )
 
     def test_msc2209(self):
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 51660b51d5..75d28a42df 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -242,7 +242,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "POST", "/register", request_data, access_token=token,
+            "POST",
+            "/register",
+            request_data,
+            access_token=token,
         )
 
         if channel.code != 200:
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 759e4cd048..f696fcf89e 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -21,7 +21,7 @@ from tests import unittest
 
 
 def get_sample_labels_value(sample):
-    """ Extract the labels and values of a sample.
+    """Extract the labels and values of a sample.
 
     prometheus_client 0.5 changed the sample type to a named tuple with more
     members than the plain tuple had in 0.4 and earlier. This function can
diff --git a/tests/test_preview.py b/tests/test_preview.py
index c19facc1cb..ea83299918 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,6 +15,7 @@
 
 from synapse.rest.media.v1.preview_url_resource import (
     decode_and_calc_og,
+    get_html_media_encoding,
     summarize_paragraphs,
 )
 
@@ -26,7 +27,7 @@ except ImportError:
     lxml = None
 
 
-class PreviewTestCase(unittest.TestCase):
+class SummarizeTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
@@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase):
         )
 
 
-class PreviewUrlTestCase(unittest.TestCase):
+class CalcOgTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
     def test_simple(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         )
 
     def test_script(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
-        html = """
+        html = b"""
         <html>
         <body>
         Some text.
@@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
-        html = """
+        html = b"""
         <html>
         <meta property="og:description" content="Some text."/>
         <body>
@@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
-        html = """
+        html = b"""
         <html>
         <body>
         <h1><a href="foo"/></h1>
@@ -258,6 +259,115 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_empty(self):
-        html = ""
+        """Test a body with no data in it."""
+        html = b""
         og = decode_and_calc_og(html, "http://example.com/test.html")
         self.assertEqual(og, {})
+
+    def test_no_tree(self):
+        """A valid body with no tree in it."""
+        html = b"\x00"
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {})
+
+    def test_invalid_encoding(self):
+        """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+        html = b"""
+        <html>
+        <head><title>Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(
+            html, "http://example.com/test.html", "invalid-encoding"
+        )
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+    def test_invalid_encoding2(self):
+        """A body which doesn't match the sent character encoding."""
+        # Note that this contains an invalid UTF-8 sequence in the title.
+        html = b"""
+        <html>
+        <head><title>\xff\xff Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
+
+
+class MediaEncodingTestCase(unittest.TestCase):
+    def test_meta_charset(self):
+        """A character encoding is found via the meta tag."""
+        encoding = get_html_media_encoding(
+            b"""
+        <html>
+        <head><meta charset="ascii">
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+        # A less well-formed version.
+        encoding = get_html_media_encoding(
+            b"""
+        <html>
+        <head>< meta charset = ascii>
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+    def test_xml_encoding(self):
+        """A character encoding is found via the meta tag."""
+        encoding = get_html_media_encoding(
+            b"""
+        <?xml version="1.0" encoding="ascii"?>
+        <html>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+    def test_meta_xml_encoding(self):
+        """Meta tags take precedence over XML encoding."""
+        encoding = get_html_media_encoding(
+            b"""
+        <?xml version="1.0" encoding="ascii"?>
+        <html>
+        <head><meta charset="UTF-16">
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "UTF-16")
+
+    def test_content_type(self):
+        """A character encoding is found via the Content-Type header."""
+        # Test a few variations of the header.
+        headers = (
+            'text/html; charset="ascii";',
+            "text/html;charset=ascii;",
+            'text/html;  charset="ascii"',
+            "text/html; charset=ascii",
+            'text/html; charset="ascii;',
+            'text/html; charset=ascii";',
+        )
+        for header in headers:
+            encoding = get_html_media_encoding(b"", header)
+            self.assertEqual(encoding, "ascii")
+
+    def test_fallback(self):
+        """A character encoding cannot be found in the body or header."""
+        encoding = get_html_media_encoding(b"", "text/html")
+        self.assertEqual(encoding, "utf-8")
diff --git a/tests/test_server.py b/tests/test_server.py
index 815da18e65..55cde7f62f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -166,7 +166,10 @@ class JsonResourceTests(unittest.TestCase):
 
         res = JsonResource(self.homeserver)
         res.register_paths(
-            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+            "GET",
+            [re.compile("^/_matrix/foo$")],
+            _callback,
+            "test_servlet",
         )
 
         # The path was registered as GET, but this is a HEAD request.
diff --git a/tests/unittest.py b/tests/unittest.py
index bbd295687c..ca7031c724 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -255,7 +255,10 @@ class HomeserverTestCase(TestCase):
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
                     self.hs.get_datastore().add_access_token_to_user(
-                        self.helper.auth_user_id, "some_fake_token", None, None,
+                        self.helper.auth_user_id,
+                        "some_fake_token",
+                        None,
+                        None,
                     )
                 )
 
@@ -386,6 +389,7 @@ class HomeserverTestCase(TestCase):
         custom_headers: Optional[
             Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
         ] = None,
+        client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -410,6 +414,9 @@ class HomeserverTestCase(TestCase):
 
             custom_headers: (name, value) pairs to add as request headers
 
+            client_ip: The IP to use as the requesting IP. Useful for testing
+                ratelimiting.
+
         Returns:
             The FakeChannel object which stores the result of the request.
         """
@@ -426,6 +433,7 @@ class HomeserverTestCase(TestCase):
             content_is_form,
             await_result,
             custom_headers,
+            client_ip,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
new file mode 100644
index 0000000000..f349b5ced0
--- /dev/null
+++ b/tests/util/caches/test_cached_call.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+
+from synapse.util.caches.cached_call import CachedCall, RetryOnExceptionCachedCall
+
+from tests.test_utils import get_awaitable_result
+from tests.unittest import TestCase
+
+
+class CachedCallTestCase(TestCase):
+    def test_get(self):
+        """
+        Happy-path test case: makes a couple of calls and makes sure they behave
+        correctly
+        """
+        d = Deferred()
+
+        async def f():
+            return await d
+
+        slow_call = Mock(side_effect=f)
+
+        cached_call = CachedCall(slow_call)
+
+        # the mock should not yet have been called
+        slow_call.assert_not_called()
+
+        # now fire off a couple of calls
+        completed_results = []
+
+        async def r():
+            res = await cached_call.get()
+            completed_results.append(res)
+
+        r1 = defer.ensureDeferred(r())
+        r2 = defer.ensureDeferred(r())
+
+        # neither result should be complete yet
+        self.assertNoResult(r1)
+        self.assertNoResult(r2)
+
+        # and the mock should have been called *once*, with no params
+        slow_call.assert_called_once_with()
+
+        # allow the deferred to complete, which should complete both the pending results
+        d.callback(123)
+        self.assertEqual(completed_results, [123, 123])
+        self.successResultOf(r1)
+        self.successResultOf(r2)
+
+        # another call to the getter should complete immediately
+        slow_call.reset_mock()
+        r3 = get_awaitable_result(cached_call.get())
+        self.assertEqual(r3, 123)
+        slow_call.assert_not_called()
+
+    def test_fast_call(self):
+        """
+        Test the behaviour when the underlying function completes immediately
+        """
+
+        async def f():
+            return 12
+
+        fast_call = Mock(side_effect=f)
+        cached_call = CachedCall(fast_call)
+
+        # the mock should not yet have been called
+        fast_call.assert_not_called()
+
+        # run the call a couple of times, which should complete immediately
+        self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+        self.assertEqual(get_awaitable_result(cached_call.get()), 12)
+
+        # the mock should have been called once
+        fast_call.assert_called_once_with()
+
+
+class RetryOnExceptionCachedCallTestCase(TestCase):
+    def test_get(self):
+        # set up the RetryOnExceptionCachedCall around a function which will fail
+        # (after a while)
+        d = Deferred()
+
+        async def f1():
+            await d
+            raise ValueError("moo")
+
+        slow_call = Mock(side_effect=f1)
+        cached_call = RetryOnExceptionCachedCall(slow_call)
+
+        # the mock should not yet have been called
+        slow_call.assert_not_called()
+
+        # now fire off a couple of calls
+        completed_results = []
+
+        async def r():
+            try:
+                await cached_call.get()
+            except Exception as e1:
+                completed_results.append(e1)
+
+        r1 = defer.ensureDeferred(r())
+        r2 = defer.ensureDeferred(r())
+
+        # neither result should be complete yet
+        self.assertNoResult(r1)
+        self.assertNoResult(r2)
+
+        # and the mock should have been called *once*, with no params
+        slow_call.assert_called_once_with()
+
+        # complete the deferred, which should make the pending calls fail
+        d.callback(0)
+        self.assertEqual(len(completed_results), 2)
+        for e in completed_results:
+            self.assertIsInstance(e, ValueError)
+            self.assertEqual(e.args, ("moo",))
+
+        # reset the mock to return a successful result, and make another pair of calls
+        # to the getter
+        d = Deferred()
+
+        async def f2():
+            return await d
+
+        slow_call.reset_mock()
+        slow_call.side_effect = f2
+        r3 = defer.ensureDeferred(cached_call.get())
+        r4 = defer.ensureDeferred(cached_call.get())
+
+        self.assertNoResult(r3)
+        self.assertNoResult(r4)
+        slow_call.assert_called_once_with()
+
+        # let that call complete, and check the results
+        d.callback(123)
+        self.assertEqual(self.successResultOf(r3), 123)
+        self.assertEqual(self.successResultOf(r4), 123)
+
+        # and now more calls to the getter should complete immediately
+        slow_call.reset_mock()
+        self.assertEqual(get_awaitable_result(cached_call.get()), 123)
+        slow_call.assert_not_called()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index ecd9efc4df..c24c33ee91 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -232,7 +232,10 @@ class DeferredCacheTestCase(TestCase):
 
     def test_eviction_iterable(self):
         cache = DeferredCache(
-            "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+            "test",
+            max_entries=3,
+            apply_cache_factor_from_config=False,
+            iterable=True,
         )
 
         cache.prefill(1, ["one", "two"])
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index cf1e3203a4..afb11b9caf 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -143,8 +143,7 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     def test_cache_with_sync_exception(self):
-        """If the wrapped function throws synchronously, things should continue to work
-        """
+        """If the wrapped function throws synchronously, things should continue to work"""
 
         class Cls:
             @cached()
@@ -165,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d, SynapseError)
 
     def test_cache_with_async_exception(self):
-        """The wrapped function returns a failure
-        """
+        """The wrapped function returns a failure"""
 
         class Cls:
             result = None
@@ -282,7 +280,8 @@ class DescriptorTestCase(unittest.TestCase):
                 try:
                     d = obj.fn(1)
                     self.assertEqual(
-                        current_context(), SENTINEL_CONTEXT,
+                        current_context(),
+                        SENTINEL_CONTEXT,
                     )
                     yield d
                     self.fail("No exception thrown")
@@ -374,8 +373,7 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     def test_cache_iterable_with_sync_exception(self):
-        """If the wrapped function throws synchronously, things should continue to work
-        """
+        """If the wrapped function throws synchronously, things should continue to work"""
 
         class Cls:
             @descriptors.cached(iterable=True)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1ef0af8e8f..e931a7ec18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -24,28 +24,32 @@ class ChunkSeqTests(TestCase):
         parts = chunk_seq("123", 8)
 
         self.assertEqual(
-            list(parts), ["123"],
+            list(parts),
+            ["123"],
         )
 
     def test_long_seq(self):
         parts = chunk_seq("abcdefghijklmnop", 8)
 
         self.assertEqual(
-            list(parts), ["abcdefgh", "ijklmnop"],
+            list(parts),
+            ["abcdefgh", "ijklmnop"],
         )
 
     def test_uneven_parts(self):
         parts = chunk_seq("abcdefghijklmnop", 5)
 
         self.assertEqual(
-            list(parts), ["abcde", "fghij", "klmno", "p"],
+            list(parts),
+            ["abcde", "fghij", "klmno", "p"],
         )
 
     def test_empty_input(self):
         parts = chunk_seq([], 5)
 
         self.assertEqual(
-            list(parts), [],
+            list(parts),
+            [],
         )
 
 
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 13b753e367..9ed01f7e0c 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -70,7 +70,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         self.assertTrue("user@foo.com" not in cache._entity_to_key)
 
         self.assertEqual(
-            cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+            cache.get_all_entities_changed(2),
+            ["bar@baz.net", "user@elsewhere.org"],
         )
         self.assertIsNone(cache.get_all_entities_changed(1))
 
@@ -80,7 +81,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
             {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
         )
         self.assertEqual(
-            cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+            cache.get_all_entities_changed(2),
+            ["user@elsewhere.org", "bar@baz.net"],
         )
         self.assertIsNone(cache.get_all_entities_changed(1))
 
@@ -222,7 +224,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         # Query a subset of the entries mid-way through the stream. We should
         # only get back the subset.
         self.assertEqual(
-            cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
+            cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
+            {"bar@baz.net"},
         )
 
     def test_max_pos(self):
diff --git a/tests/utils.py b/tests/utils.py
index 09614093bc..4fb5098550 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -158,7 +157,9 @@ def default_config(name, parse=False):
             "local": {"per_second": 10000, "burst_count": 10000},
             "remote": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
         "saml2_enabled": False,
+        "public_baseurl": None,
         "default_identity_server": None,
         "key_refresh_interval": 24 * 60 * 60 * 1000,
         "old_signing_keys": {},
@@ -262,7 +263,10 @@ def setup_test_homeserver(
         db_conn.close()
 
     hs = homeserver_to_use(
-        name, config=config, version_string="Synapse/tests", reactor=reactor,
+        name,
+        config=config,
+        version_string="Synapse/tests",
+        reactor=reactor,
     )
 
     # Install @cache_in_self attributes
@@ -351,7 +355,7 @@ def mock_getRawHeaders(headers=None):
 
 
 # This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
     def __init__(self, prefix=""):
         self.callbacks = []  # 3-tuple of method/pattern/function
         self.prefix = prefix
@@ -364,7 +368,7 @@ class MockHttpResource(HttpServer):
     def trigger(
         self, http_method, path, content, mock_request, federation_auth_origin=None
     ):
-        """ Fire an HTTP event.
+        """Fire an HTTP event.
 
         Args:
             http_method : The HTTP method
@@ -527,8 +531,7 @@ class MockClock:
 
 
 async def create_room(hs, room_id: str, creator_id: str):
-    """Creates and persist a creation event for the given room
-    """
+    """Creates and persist a creation event for the given room"""
 
     persistence_store = hs.get_storage().persistence
     store = hs.get_datastore()