summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_filtering.py16
-rw-r--r--tests/app/test_frontend_proxy.py6
-rw-r--r--tests/app/test_openid_listener.py8
-rw-r--r--tests/config/test_util.py53
-rw-r--r--tests/crypto/test_keyring.py16
-rw-r--r--tests/events/test_utils.py185
-rw-r--r--tests/federation/test_complexity.py4
-rw-r--r--tests/federation/test_federation_server.py6
-rw-r--r--tests/federation/transport/__init__.py0
-rw-r--r--tests/federation/transport/test_server.py39
-rw-r--r--tests/handlers/test_cas.py121
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py14
-rw-r--r--tests/handlers/test_federation.py54
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py519
-rw-r--r--tests/handlers/test_password_providers.py29
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py32
-rw-r--r--tests/handlers/test_saml.py155
-rw-r--r--tests/handlers/test_typing.py19
-rw-r--r--tests/handlers/test_user_directory.py48
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py32
-rw-r--r--tests/http/test_additional_resource.py8
-rw-r--r--tests/http/test_client.py101
-rw-r--r--tests/http/test_endpoint.py2
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/http/test_proxyagent.py130
-rw-r--r--tests/logging/test_terse_json.py70
-rw-r--r--tests/push/test_email.py36
-rw-r--r--tests/push/test_http.py118
-rw-r--r--tests/push/test_presentable_names.py229
-rw-r--r--tests/push/test_push_rule_evaluator.py2
-rw-r--r--tests/replication/_base.py59
-rw-r--r--tests/replication/test_auth.py117
-rw-r--r--tests/replication/test_client_reader_shard.py36
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/replication/test_multi_media_repo.py4
-rw-r--r--tests/replication/test_pusher_shard.py14
-rw-r--r--tests/replication/test_sharded_event_persister.py16
-rw-r--r--tests/rest/admin/test_admin.py45
-rw-r--r--tests/rest/admin/test_device.py114
-rw-r--r--tests/rest/admin/test_event_reports.py66
-rw-r--r--tests/rest/admin/test_media.py64
-rw-r--r--tests/rest/admin/test_room.py289
-rw-r--r--tests/rest/admin/test_statistics.py61
-rw-r--r--tests/rest/admin/test_user.py992
-rw-r--r--tests/rest/client/test_consent.py8
-rw-r--r--tests/rest/client/test_ephemeral_message.py2
-rw-r--r--tests/rest/client/test_identity.py6
-rw-r--r--tests/rest/client/test_redactions.py8
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_shadow_banned.py24
-rw-r--r--tests/rest/client/test_third_party_rules.py10
-rw-r--r--tests/rest/client/v1/test_directory.py8
-rw-r--r--tests/rest/client/v1/test_events.py8
-rw-r--r--tests/rest/client/v1/test_login.py609
-rw-r--r--tests/rest/client/v1/test_presence.py6
-rw-r--r--tests/rest/client/v1/test_profile.py18
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py66
-rw-r--r--tests/rest/client/v1/test_rooms.py283
-rw-r--r--tests/rest/client/v1/test_typing.py10
-rw-r--r--tests/rest/client/v1/utils.py268
-rw-r--r--tests/rest/client/v2_alpha/test_account.py148
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py260
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py8
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py14
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py18
-rw-r--r--tests/rest/client/v2_alpha/test_register.py71
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py36
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py16
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py32
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/media/v1/test_media_storage.py44
-rw-r--r--tests/rest/media/v1/test_url_preview.py69
-rw-r--r--tests/rest/test_health.py4
-rw-r--r--tests/rest/test_well_known.py8
-rw-r--r--tests/server.py55
-rw-r--r--tests/server_notices/test_consent.py6
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py8
-rw-r--r--tests/state/test_v2.py193
-rw-r--r--tests/storage/test_account_data.py120
-rw-r--r--tests/storage/test_devices.py26
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_event_chain.py741
-rw-r--r--tests/storage/test_event_federation.py270
-rw-r--r--tests/storage/test_events.py334
-rw-r--r--tests/storage/test_id_generators.py112
-rw-r--r--tests/storage/test_main.py7
-rw-r--r--tests/storage/test_profile.py26
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_roommember.py8
-rw-r--r--tests/storage/test_user_directory.py23
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/test_mau.py48
-rw-r--r--tests/test_preview.py67
-rw-r--r--tests/test_server.py30
-rw-r--r--tests/test_terms_auth.py6
-rw-r--r--tests/test_types.py4
-rw-r--r--tests/test_utils/__init__.py39
-rw-r--r--tests/test_utils/html_parsers.py53
-rw-r--r--tests/test_utils/logging_setup.py2
-rw-r--r--tests/unittest.py134
-rw-r--r--tests/util/caches/test_deferred_cache.py73
-rw-r--r--tests/util/test_itertools.py61
-rw-r--r--tests/utils.py112
107 files changed, 6750 insertions, 1844 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c98ae75974..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from mock import Mock
-
 import jsonschema
 
 from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
 from synapse.events import make_event_from_dict
 
 from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
 
 user_localpart = "test_user"
 
@@ -42,19 +40,9 @@ def MockEvent(**kwargs):
 
 
 class FilteringTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation_resource = MockHttpResource()
-
-        self.mock_http_client = Mock(spec=[])
-        self.mock_http_client.put_json = DeferredMockCallable()
-
-        hs = yield setup_test_homeserver(
-            self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
-        )
-
+        hs = setup_test_homeserver(self.addCleanup)
         self.filtering = hs.get_filtering()
-
         self.datastore = hs.get_datastore()
 
     def test_errors_on_invalid_filters(self):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 40abe9d72d..e0ca288829 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
 
         return hs
@@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase):
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
 
-        _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        channel = make_request(self.reactor, site, "PUT", "presence/a/status")
 
         # 400 + unrecognised, because nothing is registered
         self.assertEqual(channel.code, 400)
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
 
-        _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        channel = make_request(self.reactor, site, "PUT", "presence/a/status")
 
         # 401, because the stub servlet still checks authentication
         self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index ea3be95cf1..467033e201 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
 class FederationReaderOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
         return hs
 
@@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
                 return
             raise
 
-        _, channel = make_request(
+        channel = make_request(
             self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
 
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=SynapseHomeServer
+            federation_http_client=None, homeserver_to_use=SynapseHomeServer
         )
         return hs
 
@@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
                 return
             raise
 
-        _, channel = make_request(
+        channel = make_request(
             self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
 
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correclty reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 697916a019..1d65ea2f9c 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock()
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Tests that we correctly handle key requests for keys we've stored
         with a null `ts_valid_until_ms`
         """
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
-        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2 = Mock()
         mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
 
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        hs = self.setup_test_homeserver(http_client=self.http_client)
+        hs = self.setup_test_homeserver(federation_http_client=self.http_client)
         return hs
 
     def test_get_keys_from_server(self):
@@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             }
         ]
 
-        return self.setup_test_homeserver(http_client=self.http_client, config=config)
+        return self.setup_test_homeserver(
+            federation_http_client=self.http_client, config=config
+        )
 
     def build_perspectives_response(
         self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
 
 
 class PruneEventTestCase(unittest.TestCase):
-    """ Asserts that a new event constructed with `evdict` will look like
-    `matchdict` when it is redacted. """
-
     def run_test(self, evdict, matchdict, **kwargs):
-        self.assertEquals(
+        """
+        Asserts that a new event constructed with `evdict` will look like
+        `matchdict` when it is redacted.
+
+        Args:
+             evdict: The dictionary to build the event from.
+             matchdict: The expected resulting dictionary.
+             kwargs: Additional keyword arguments used to create the event.
+        """
+        self.assertEqual(
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
     def test_basic_keys(self):
+        """Ensure that the keys that should be untouched are kept."""
+        # Note that some of the values below don't really make sense, but the
+        # pruning of events doesn't worry about the values of any fields (with
+        # the exception of the content field).
         self.run_test(
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "content": {"other_key": "foo"},
+                "hashes": "hashes",
+                "signatures": {"domain": {"algo:1": "sigs"}},
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
+                # Also include a key that should be removed.
+                "other_key": "foo",
             },
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "hashes": "hashes",
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
                 "content": {},
-                "signatures": {},
+                "signatures": {"domain": {"algo:1": "sigs"}},
                 "unsigned": {},
             },
         )
 
-    def test_unsigned_age_ts(self):
+        # As of MSC2176 we now redact the membership and prev_states keys.
         self.run_test(
-            {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
-            {
-                "type": "B",
-                "event_id": "$test:domain",
-                "content": {},
-                "signatures": {},
-                "unsigned": {"age_ts": 20},
-            },
+            {"type": "A", "prev_state": "prev_state", "membership": "join"},
+            {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+            room_version=RoomVersions.MSC2176,
         )
 
+    def test_unsigned(self):
+        """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
                 "type": "B",
                 "event_id": "$test:domain",
-                "unsigned": {"other_key": "here"},
+                "unsigned": {
+                    "age_ts": 20,
+                    "replaces_state": "$test2:domain",
+                    "other_key": "foo",
+                },
             },
             {
                 "type": "B",
                 "event_id": "$test:domain",
                 "content": {},
                 "signatures": {},
-                "unsigned": {},
+                "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
             },
         )
 
     def test_content(self):
+        """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
             {
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # Some events keep a single content key/value.
+        EVENT_KEEP_CONTENT_KEYS = [
+            ("member", "membership", "join"),
+            ("join_rules", "join_rule", "invite"),
+            ("history_visibility", "history_visibility", "shared"),
+        ]
+        for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+            self.run_test(
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value, "other_key": "foo"},
+                },
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value},
+                    "signatures": {},
+                    "unsigned": {},
+                },
+            )
+
+    def test_create(self):
+        """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
                 "type": "m.room.create",
                 "event_id": "$test:domain",
-                "content": {"creator": "@2:domain", "other_field": "here"},
+                "content": {"creator": "@2:domain", "other_key": "foo"},
             },
             {
                 "type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # After MSC2176, create events get nothing redacted.
+        self.run_test(
+            {"type": "m.room.create", "content": {"not_a_real_key": True}},
+            {
+                "type": "m.room.create",
+                "content": {"not_a_real_key": True},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
+    def test_power_levels(self):
+        """Power level events keep a variety of content keys."""
+        self.run_test(
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    "invite": 3,
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                    "other_key": 8,
+                },
+            },
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    # Note that invite is not here.
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                },
+                "signatures": {},
+                "unsigned": {},
+            },
+        )
+
+        # After MSC2176, power levels events keep the invite key.
+        self.run_test(
+            {"type": "m.room.power_levels", "content": {"invite": 75}},
+            {
+                "type": "m.room.power_levels",
+                "content": {"invite": 75},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
     def test_alias_event(self):
         """Alias events have special behavior up through room version 6."""
         self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
-    def test_msc2432_alias_event(self):
-        """After MSC2432, alias events have no special behavior."""
+        # After MSC2432, alias events have no special behavior.
         self.run_test(
             {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
             {
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
+    def test_redacts(self):
+        """Redaction events have no special behaviour until MSC2174/MSC2176."""
+
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.V6,
+        )
+
+        # After MSC2174, redaction events keep the redacts content key.
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {"redacts": "$test2:domain"},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
 
 class SerializeEventTestCase(unittest.TestCase):
     def serialize(self, ev, fields):
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0187f56e21..9ccd2d76b8 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Get the room complexity
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
@@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3009fbb6c4..cfeccc0577 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
 
         "/get_missing_events/(?P<room_id>[^/]*)/?"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
             query_content,
@@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
         self.inject_room_member(room_1, "@user:other.example.com", "join")
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
 
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index f9e3c7a51f..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,38 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
 from tests import unittest
 from tests.unittest import override_config
 
 
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        class Authenticator:
-            def authenticate_request(self, request, content):
-                return defer.succeed("otherserver.nottld")
-
-        ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
-        server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
     @override_config({"allow_public_rooms_over_federation": False})
     def test_blocked_public_room_list_over_federation(self):
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory 403 when
+        allow_public_rooms_over_federation is False.
+        """
+        channel = self.make_request(
+            "GET",
+            "/_matrix/federation/v1/publicRooms",
+            federation_auth_origin=b"example.com",
         )
         self.assertEquals(403, channel.code)
 
     @override_config({"allow_public_rooms_over_federation": True})
     def test_open_public_room_list_over_federation(self):
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory 200 when
+        allow_public_rooms_over_federation is True.
+        """
+        channel = self.make_request(
+            "GET",
+            "/_matrix/federation/v1/publicRooms",
+            federation_auth_origin=b"example.com",
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7baf224f7e
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+#  Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        cas_config = {
+            "enabled": True,
+            "server_url": SERVER_URL,
+            "service_url": BASE_URL,
+        }
+        config["cas_config"] = cas_config
+
+        return config
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        self.handler = hs.get_cas_handler()
+
+        # Reduce the number of attempts when generating MXIDs.
+        sso_handler = hs.get_sso_handler()
+        sso_handler._MAP_USERNAME_RETRIES = 3
+
+        return hs
+
+    def test_map_cas_user_to_user(self):
+        """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
+    def test_map_cas_user_to_existing_user(self):
+        """Existing users can log in with CAS account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # Map a user via SSO.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+        # Subsequent calls should map to the same mxid.
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+    def test_map_cas_user_to_invalid_localpart(self):
+        """CAS automaps invalid characters to base-64 encoding."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("föö", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+        )
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.mock_registry.register_query_handler = register_query_handler
 
         hs = self.setup_test_homeserver(
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_registry=self.mock_registry,
         )
@@ -407,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_denied(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_allowed(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23unofficial_test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = True
 
         # Room list is enabled so we should get some results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) > 0)
 
@@ -456,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = False
 
         # Room list disabled so we should get no results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) == 0)
 
         # Room list disabled so we shouldn't be allowed to publish rooms
         room_id = self.helper.create_room_as(self.user_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
 from unittest import TestCase
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
 from synapse.federation.federation_base import event_from_pdu_json
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(http_client=None)
+        hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         return hs
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invite_by_user_ratelimit(self):
+        """Tests that invites from federation to a particular user are
+        actually rate-limited.
+        """
+        other_server = "otherserver"
+        other_user = "@otheruser:" + other_server
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        def create_invite():
+            room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+            room_version = self.get_success(self.store.get_room_version(room_id))
+            return event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": other_user,
+                    "state_key": "@user:test",
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+
+        for i in range(3):
+            event = create_invite()
+            self.get_success(
+                self.handler.on_invite_request(other_server, event, event.room_version,)
+            )
+
+        event = create_invite()
+        self.get_failure(
+            self.handler.on_invite_request(other_server, event, event.room_version,),
+            exc=LimitExceededError,
+        )
+
     def _build_and_send_join_event(self, other_server, other_user, room_id):
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +242,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # the auth code requires that a signature exists, but doesn't check that
         # signature... go figure.
         join_event.signatures[other_server] = {"x": "y"}
-        with LoggingContext(request="send_join"):
+        with LoggingContext("send_join"):
             d = run_in_background(
                 self.handler.on_send_join_request, other_server, join_event
             )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..f955dfa490 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
 
         # Redaction of event should fail.
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", path, content={}, access_token=self.access_token
         )
         self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
-from mock import Mock, patch
+from mock import ANY, Mock, patch
 
-import attr
 import pymacaroons
 
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+try:
+    import authlib  # noqa: F401
 
-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
+    HAS_OIDC = True
+except ImportError:
+    HAS_OIDC = False
 
 
 # These are a few constants that are used as config parameters in the tests.
@@ -46,7 +40,7 @@ ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
 CLIENT_SECRET = "test-client-secret"
 BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
 AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -64,17 +58,14 @@ COMMON_CONFIG = {
 }
 
 
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
     @staticmethod
     def parse_config(config):
         return
 
+    def __init__(self, config):
+        pass
+
     def get_remote_user_id(self, userinfo):
         return userinfo["sub"]
 
@@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None):
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
@@ -127,6 +108,9 @@ async def get_json(url):
 
 
 class OidcHandlerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
@@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._providers["oidc"]
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
         self.render_error.reset_mock()
+        return args
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # For some reason, call.args does not work with python3.5
         args = calls[0][0]
         kwargs = calls[0][1]
-        self.assertEqual(args[0], COOKIE_NAME)
-        self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        self.assertEqual(args[0], b"oidc_session")
+        self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
          - when the userinfo fetching fails
          - when the code exchange fails
         """
+
+        # ensure that we are correctly testing the fallback when "get_extra_attributes"
+        # is not implemented.
+        mapping_provider = self.provider._user_mapping_provider
+        with self.assertRaises(AttributeError):
+            _ = mapping_provider.get_extra_attributes
+
         token = {
             "type": "bearer",
             "id_token": "id_token",
             "access_token": "access_token",
         }
+        username = "bar"
         userinfo = {
             "sub": "foo",
-            "preferred_username": "bar",
+            "username": username,
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         code = "code"
         state = "state"
@@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        request = _build_callback_request(
+            code, state, session, user_agent=user_agent, ip_address=ip_address
         )
 
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
-
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
-        self.handler._map_userinfo_to_user = simple_async_mock(
-            raises=MappingException()
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("mapping_error")
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        with patch.object(
+            self.provider,
+            "_remote_id_from_userinfo",
+            new=Mock(side_effect=MappingException()),
+        ):
+            self.get_success(self.handler.handle_oidc_callback(request))
+            self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        self.handler._auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._map_userinfo_to_user.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        auth_handler.complete_sso_login.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
-        )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
-        self.handler._exchange_code = simple_async_mock(
+        from synapse.handlers.oidc_handler import OidcError
+
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        from synapse.handlers.oidc_handler import OidcError
+
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         userinfo = {
             "sub": "foo",
+            "username": "foo",
             "phone": "1234567",
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
-
-        request.args = {}
-        request.args[b"code"] = [b"code"]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = "10.0.0.1"
-        request.get_user_agent.return_value = "Browser"
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {"phone": "1234567"},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         userinfo = {
             "sub": "test_user",
             "username": "test_user",
         }
-        # The token doesn't matter with the default user mapping provider.
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user_2:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastore()
@@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error",
+            "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        args = self.assertRenderedError("mapping_error")
         self.assertTrue(
-            str(e.value).startswith(
+            args[2].startswith(
                 "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
             )
         )
@@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@TEST_USER_2:test")
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        userinfo = {
-            "sub": "test2",
-            "username": "föö",
-        }
-        token = {}
-
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
         {
@@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", ANY, ANY, None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+    def test_empty_localpart(self):
+        """Attempts to map onto an empty localpart should be rejected."""
+        userinfo = {
+            "sub": "tester",
+            "username": "",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.username }}"}
+                }
+            }
+        }
+    )
+    def test_null_localpart(self):
+        """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+        userinfo = {
+            "sub": "tester",
+            "username": None,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
+
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                idp_id="oidc",
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
+
+
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
+
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
+    handler = hs.get_oidc_handler()
+    provider = handler._providers["oidc"]
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+    state = "state"
+    session = handler._token_generator.generate_oidc_session_token(
+        state=state,
+        session_data=OidcSessionData(
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "getHeader",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..f816594ee4 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
+            **providers_config(CustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled(self):
+        """Check the localdb_enabled == enabled == False
+
+        Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+        that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+        cause an exception.
+        """
+        self.register_user("localuser", "localpass")
+
+        flows = self._get_login_flows()
+        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+        # login shouldn't work and should be rejected with a 400 ("unknown login type")
+        channel = self._send_password_login("localuser", "localpass")
+        self.assertEqual(channel.code, 400, channel.result)
+        mock_password_provider.check_auth.assert_not_called()
+
+    @override_config(
+        {
             **providers_config(PasswordCustomAuthProvider),
             "password_config": {"enabled": False},
         }
@@ -528,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
 
     def _get_login_flows(self) -> JsonDict:
-        _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["flows"]
 
@@ -537,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     def _send_login(self, type, user, **params) -> FakeChannel:
         params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
-        _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+        channel = self.make_request("POST", "/_matrix/client/r0/login", params)
         return channel
 
     def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -574,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
-        _, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "devices/" + device, body, access_token=access_token
         )
         return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", http_client=None, federation_sender=Mock()
+            "server", federation_http_client=None, federation_sender=Mock()
         )
         return hs
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
 
         hs = yield setup_test_homeserver(
             self.addCleanup,
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
@@ -107,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
             "Frank",
         )
 
+        # Set displayname to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), ""
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
@@ -225,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
             "http://my.server/me.png",
         )
 
+        # Set avatar to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank, synapse.types.create_requester(self.frank), "",
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
+        )
+
     @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+from typing import Optional
+
+from mock import Mock
+
 import attr
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
 
+from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+# Check if we have the dependencies to run the tests.
+try:
+    import saml2.config
+    from saml2.sigver import SigverError
+
+    has_saml2 = True
+
+    # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+    config = saml2.config.SPConfig()
+    try:
+        config.load({"metadata": {}})
+        has_xmlsec1 = True
+    except SigverError:
+        has_xmlsec1 = False
+except ImportError:
+    has_saml2 = False
+    has_xmlsec1 = False
+
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
 
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
 @attr.s
 class FakeAuthnResponse:
     ava = attr.ib(type=dict)
+    assertions = attr.ib(type=list, factory=list)
+    in_response_to = attr.ib(type=Optional[str], default=None)
 
 
 class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         return hs
 
+    if not has_saml2:
+        skip = "Requires pysaml2"
+    elif not has_xmlsec1:
+        skip = "Requires xmlsec1"
+
     def test_map_saml_response_to_user(self):
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
-        # The redirect_url doesn't matter with the default user mapping provider.
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
 
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
             {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
         )
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # mock out the error renderer too
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
-        redirect_url = ""
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
+        )
+        sso_handler.render_error.assert_called_once_with(
+            request, "mapping_error", "localpart is invalid: föö"
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        auth_handler.complete_sso_login.assert_not_called()
 
     def test_map_saml_response_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+        # stub out the auth handler and error renderer
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
+        # register a user to occupy the first-choice MXID
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
+
+        # send the fake SAML response
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", request, "", None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular SAML username.
         self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # Now attempt to map to a username, this will fail since all potential usernames are taken.
         saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+        sso_handler.render_error.assert_called_once_with(
+            request,
+            "mapping_error",
+            "Unable to generate a Matrix ID from the SSO response",
         )
+        auth_handler.complete_sso_login.assert_not_called()
 
     @override_config(
         {
@@ -185,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
         }
     )
     def test_map_saml_response_redirect(self):
+        """Test a mapping provider that raises a RedirectException"""
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
+        request = _mock_request()
         e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
+            self.handler._handle_authn_response(request, saml_response, ""),
             RedirectException,
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
 
 
 import json
+from typing import Dict
 
 from mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
-from tests.utils import register_federation_servlets
 
 # Some local users to test with
 U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             notifier=Mock(),
-            http_client=mock_federation_client,
+            federation_http_client=mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
 
         return hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        (request, channel) = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/federation/v1/send/1000000",
             _make_edu_transaction_json(
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
                 user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
+        regular_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=regular_user_id, password_hash=None)
+        )
 
         self.get_success(
             self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         display_name = "display_name"
 
         profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
-        regular_user_id = "@regular:test"
         self.get_success(
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
         self.assertTrue(profile["display_name"] == display_name)
 
+    def test_handle_local_profile_change_with_deactivated_user(self):
+        # create user
+        r_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=r_user_id, password_hash=None)
+        )
+
+        # update profile
+        display_name = "Regular User"
+        profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile["display_name"] == display_name)
+
+        # deactivate user
+        self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+        self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+        # profile is not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
+        # update profile after deactivation
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is furthermore not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
         self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         spam_checker = self.hs.get_spam_checker()
 
         class AllowAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
 
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Configure a spam checker that filters all users.
         class BlockAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True
 
@@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2)
 
         # Assert user directory is not empty
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
 
         # Disable user directory and check search returns nothing
         self.config.user_directory_search_enabled = False
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..686012dd25 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
 from mock import Mock
 
 import treq
+from netaddr import IPSet
 from service_identity import VerificationError
 from zope.interface import implementer
 
@@ -35,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.http.federation.srv_resolver import Server
 from synapse.http.federation.well_known_resolver import (
+    WELL_KNOWN_MAX_SIZE,
     WellKnownResolver,
     _cache_period_from_headers,
 )
@@ -103,6 +105,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=self.tls_factory,
             user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
@@ -736,6 +739,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=tls_factory,
             user_agent=b"test-agent",  # This is unused since _well_known_resolver is passed below.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
                 self.reactor,
@@ -1091,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # Expire both caches and repeat the request
         self.reactor.pump((10000.0,))
 
-        # Repated the request, this time it should fail if the lookup fails.
+        # Repeat the request, this time it should fail if the lookup fails.
         fetch_d = defer.ensureDeferred(
             self.well_known_resolver.get_well_known(b"testserv")
         )
@@ -1104,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase):
         r = self.successResultOf(fetch_d)
         self.assertEqual(r.delegated_server, None)
 
+    def test_well_known_too_large(self):
+        """A well-known query that returns a result which is too large should be rejected."""
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            response_headers={b"Cache-Control": b"max-age=1000"},
+            content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
+        )
+
+        # The result is successful, but disabled delegation.
+        r = self.successResultOf(fetch_d)
+        self.assertIsNone(r.delegated_server)
+
     def test_srv_fallbacks(self):
         """Test that other SRV results are tried if the first one fails.
         """
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 05e9c449be..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -46,16 +46,16 @@ class AdditionalResourceTests(HomeserverTestCase):
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
-        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
 
     def test_sync(self):
         handler = _SyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
-        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
new file mode 100644
index 0000000000..f17c122e93
--- /dev/null
+++ b/tests/http/test_client.py
@@ -0,0 +1,101 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from io import BytesIO
+
+from mock import Mock
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+
+from tests.unittest import TestCase
+
+
+class ReadBodyWithMaxSizeTests(TestCase):
+    def setUp(self):
+        """Start reading the body, returns the response, result and proto"""
+        self.response = Mock()
+        self.result = BytesIO()
+        self.deferred = read_body_with_max_size(self.response, self.result, 6)
+
+        # Fish the protocol out of the response.
+        self.protocol = self.response.deliverBody.call_args[0][0]
+        self.protocol.transport = Mock()
+
+    def _cleanup_error(self):
+        """Ensure that the error in the Deferred is handled gracefully."""
+        called = [False]
+
+        def errback(f):
+            called[0] = True
+
+        self.deferred.addErrback(errback)
+        self.assertTrue(called[0])
+
+    def test_no_error(self):
+        """A response that is NOT too large."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12345")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"12345")
+        self.assertEqual(self.deferred.result, 5)
+
+    def test_too_large(self):
+        """A response which is too large raises an exception."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
+
+    def test_multiple_packets(self):
+        """Data should be accummulated through mutliple packets."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12")
+        self.protocol.dataReceived(b"34")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234")
+        self.assertEqual(self.deferred.result, 4)
+
+    def test_additional_data(self):
+        """A connection can receive data after being closed."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self.protocol.transport.loseConnection.assert_called_once()
+
+        # More data might have come in.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index b2e9533b07..d06ea518ce 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
+from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
 
 from tests import unittest
 
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         f = self.failureResultOf(test_d)
-        self.assertIsInstance(f.value, ValueError)
+        self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@
 import logging
 
 import treq
+from netaddr import IPSet
 
 from twisted.internet import interfaces  # noqa: F401
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.http import HTTPChannel
 
+from synapse.http.client import BlacklistingReactorWrapper
 from synapse.http.proxyagent import ProxyAgent
 
 from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
+    def test_http_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            http_proxy=b"proxy.com:8888",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8888)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"http://test.com")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            https_proxy=b"proxy.com",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 1080)
+
+        # make a test HTTP server, and wire up the client
+        proxy_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # fish the transports back out so that we can do the old switcheroo
+        s2c_transport = proxy_server.transport
+        client_protocol = s2c_transport.other
+        c2s_transport = client_protocol.transport
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending CONNECT request
+        self.assertEqual(len(proxy_server.requests), 1)
+
+        request = proxy_server.requests[0]
+        self.assertEqual(request.method, b"CONNECT")
+        self.assertEqual(request.path, b"test.com:443")
+
+        # tell the proxy server not to close the connection
+        proxy_server.persistent = True
+
+        # this just stops the http Request trying to do a chunked response
+        # request.setHeader(b"Content-Length", b"0")
+        request.finish()
+
+        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_protocol = ssl_factory.buildProtocol(None)
+        http_server = ssl_protocol.wrappedProtocol
+
+        ssl_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, ssl_protocol)
+        )
+        c2s_transport.other = ssl_protocol
+
+        self.reactor.advance(0)
+
+        server_name = ssl_protocol._tlsConnection.get_servername()
+        expected_sni = b"test.com"
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
 
 def _wrap_server_factory_for_tls(factory, sanlist=None):
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 73f469b802..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -18,30 +18,35 @@ import logging
 from io import StringIO
 
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
 
 from tests.logging import LoggerCleanupMixin
 from tests.unittest import TestCase
 
 
 class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+    def setUp(self):
+        self.output = StringIO()
+
+    def get_log_line(self):
+        # One log message, with a single trailing newline.
+        data = self.output.getvalue()
+        logs = data.splitlines()
+        self.assertEqual(len(logs), 1)
+        self.assertEqual(data.count("\n"), 1)
+        return json.loads(logs[0])
+
     def test_terse_json_output(self):
         """
         The Terse JSON formatter converts log messages to JSON.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(TerseJsonFormatter())
         logger = self.get_logger(handler)
 
         logger.info("Hello there, %s!", "wally")
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         """
         Additional information can be included in the structured logging.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(TerseJsonFormatter())
         logger = self.get_logger(handler)
 
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
             "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
         )
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
@@ -96,26 +94,44 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         """
         The Terse JSON formatter converts log messages to JSON.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(JsonFormatter())
         logger = self.get_logger(handler)
 
         logger.info("Hello there, %s!", "wally")
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
+
+        # The terse logger should give us these keys.
+        expected_log_keys = [
+            "log",
+            "level",
+            "namespace",
+        ]
+        self.assertCountEqual(log.keys(), expected_log_keys)
+        self.assertEqual(log["log"], "Hello there, wally!")
+
+    def test_with_context(self):
+        """
+        The logging context should be added to the JSON response.
+        """
+        handler = logging.StreamHandler(self.output)
+        handler.setFormatter(JsonFormatter())
+        handler.addFilter(LoggingContextFilter())
+        logger = self.get_logger(handler)
+
+        with LoggingContext(request="test"):
+            logger.info("Hello there, %s!", "wally")
+
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
             "log",
             "level",
             "namespace",
+            "request",
         ]
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
+        self.assertEqual(log["request"], "test")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index bcdcafa5a9..c4e1e7ed85 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -187,6 +187,36 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
+    def test_multiple_rooms(self):
+        # We want to test multiple notifications from multiple rooms, so we pause
+        # processing of push while we send messages.
+        self.pusher._pause_processing()
+
+        # Create a simple room with multiple other users
+        rooms = [
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+        ]
+
+        for r, other in zip(rooms, self.others):
+            self.helper.invite(
+                room=r, src=self.user_id, tok=self.access_token, targ=other.id
+            )
+            self.helper.join(room=r, user=other.id, tok=other.token)
+
+        # The other users send some messages
+        self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+        # Nothing should have happened yet, as we're paused.
+        assert not self.email_attempts
+
+        self.pusher._resume_processing()
+
+        # We should get emailed about those messages
+        self._check_for_mail()
+
     def test_encrypted_message(self):
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -209,7 +239,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump(10)
@@ -220,7 +250,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One email was attempted to be sent
         self.assertEqual(len(self.email_attempts), 1)
@@ -238,4 +268,4 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index f118430309..60f0820cff 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import receipts
 
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
+    def default_config(self):
+        config = super().default_config()
+        config["start_pushers"] = True
+        return config
+
     def make_homeserver(self, reactor, clock):
         self.push_attempts = []
 
@@ -46,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
 
         m.post_json_get_json = post_json_get_json
 
-        config = self.default_config()
-        config["start_pushers"] = True
-
-        hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+        hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
 
         return hs
 
+    def test_invalid_configuration(self):
+        """Invalid push configurations should be rejected."""
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        def test_data(data):
+            self.get_failure(
+                self.hs.get_pusherpool().add_pusher(
+                    user_id=user_id,
+                    access_token=token_id,
+                    kind="http",
+                    app_id="m.http",
+                    app_display_name="HTTP Push Notifications",
+                    device_display_name="pushy push",
+                    pushkey="a@example.com",
+                    lang=None,
+                    data=data,
+                ),
+                PusherConfigException,
+            )
+
+        # Data must be provided with a URL.
+        test_data(None)
+        test_data({})
+        test_data({"url": 1})
+        # A bare domain name isn't accepted.
+        test_data({"url": "example.com"})
+        # A URL without a path isn't accepted.
+        test_data({"url": "http://example.com"})
+        # A url with an incorrect path isn't accepted.
+        test_data({"url": "http://example.com/foo"})
+
     def test_sends_http(self):
         """
         The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -82,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -102,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
@@ -113,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One push was attempted to be sent -- it'll be the first message
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
         )
@@ -132,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Now it'll try and send the second push message, which will be the second one
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
         )
@@ -152,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
     def test_sends_high_priority_for_encrypted(self):
         """
@@ -194,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -230,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Add yet another person — we want to make this room not a 1:1
@@ -268,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
 
     def test_sends_high_priority_for_one_to_one_only(self):
@@ -310,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -326,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority — this is a one-to-one room
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Yet another user joins
@@ -345,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -392,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -408,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time with no mention
@@ -417,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -465,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -485,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time as someone without the power of @room
@@ -496,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -570,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -589,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # Check that the unread count for the room is 0
         #
@@ -603,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # This will actually trigger a new notification to be sent out so that
         # even if the user does not receive another message, their unread
         # count goes down
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
             {},
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+    """
+    A fake data store which stores a mapping of state key to event content.
+    (I.e. the state key is used as the event ID.)
+    """
+
+    def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+        """
+        Args:
+            events: A state map to event contents.
+        """
+        self._events = {}
+
+        for i, (event_id, content) in enumerate(events):
+            self._events[event_id] = FrozenEvent(
+                {
+                    "event_id": "$event_id",
+                    "type": event_id[0],
+                    "sender": "@user:test",
+                    "state_key": event_id[1],
+                    "room_id": "#room:test",
+                    "content": content,
+                    "origin_server_ts": i,
+                },
+                RoomVersions.V1,
+            )
+
+    async def get_event(
+        self, event_id: StateKey, allow_none: bool = False
+    ) -> Optional[FrozenEvent]:
+        assert allow_none, "Mock not configured for allow_none = False"
+
+        return self._events.get(event_id)
+
+    async def get_events(self, event_ids: Iterable[StateKey]):
+        # This is cheating since it just returns all events.
+        return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+    USER_ID = "@test:test"
+    OTHER_USER_ID = "@user:test"
+
+    def _calculate_room_name(
+        self,
+        events: StateMap[dict],
+        user_id: str = "",
+        fallback_to_members: bool = True,
+        fallback_to_single_member: bool = True,
+    ):
+        # This isn't 100% accurate, but works with MockDataStore.
+        room_state_ids = {k[0]: k[0] for k in events}
+
+        return self.get_success(
+            calculate_room_name(
+                MockDataStore(events),
+                room_state_ids,
+                user_id or self.USER_ID,
+                fallback_to_members,
+                fallback_to_single_member,
+            )
+        )
+
+    def test_name(self):
+        """A room name event should be used."""
+        events = [
+            ((EventTypes.Name, ""), {"name": "test-name"}),
+        ]
+        self.assertEqual("test-name", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.Name, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.Name, ""), {"name": 1})]
+        self.assertEqual(1, self._calculate_room_name(events))
+
+    def test_canonical_alias(self):
+        """An canonical alias should be used."""
+        events = [
+            ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+        ]
+        self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_invite(self):
+        """An invite has special behaviour."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+        ]
+        self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+        # Ensure this logic is skipped if we don't fallback to members.
+        self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+        # Check if the event content has garbage.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+        ]
+        self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+        # No member event for sender.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+        ]
+        self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+    def test_no_members(self):
+        """Behaviour of an empty room."""
+        events = []
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        # Note that events with invalid (or missing) membership are ignored.
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+            ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+        ]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_no_other_members(self):
+        """Behaviour of a room with no other members in it."""
+        events = [
+            (
+                (EventTypes.Member, self.USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Me"},
+            ),
+        ]
+        self.assertEqual("Me", self._calculate_room_name(events))
+
+        # Check if the event content has no displayname.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("@test:test", self._calculate_room_name(events))
+
+        # 3pid invite, use the other user (who is set as the sender).
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual(
+            "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+        )
+
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+            ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+        ]
+        self.assertEqual(
+            "Inviting email address",
+            self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+        )
+
+    def test_one_other_member(self):
+        """Behaviour of a room with a single other member."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+        ]
+        self.assertEqual("Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+
+        # Check if the event content has no displayname and is an invite.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.INVITE},
+            ),
+        ]
+        self.assertEqual("@user:test", self._calculate_room_name(events))
+
+    def test_other_members(self):
+        """Behaviour of a room with multiple other members."""
+        # Two other members.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+            ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+        # Three or more other members.
+        events.append(
+            ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+        )
+        self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                 "type": "m.room.history_visibility",
                 "sender": "@user:test",
                 "state_key": "",
-                "room_id": "@room:test",
+                "room_id": "#room:test",
                 "content": content,
             },
             RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.worker_hs = self.setup_test_homeserver(
-            http_client=None,
+            federation_http_client=None,
             homeserver_to_use=GenericWorkerServer,
             config=self._get_worker_hs_config(),
             reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
@@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Fake in memory Redis server that servers can connect to.
         self._redis_server = FakeRedisPubSubServer()
 
+        # We may have an attempt to connect to redis for the external cache already.
+        self.connect_any_redis_attempts()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
@@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
             extra_config: Any extra config to use for this instances.
             **kwargs: Options that get passed to `self.setup_test_homeserver`,
-                useful to e.g. pass some mocks for things like `http_client`
+                useful to e.g. pass some mocks for things like `federation_http_client`
 
         Returns:
             The new worker HomeServer instance.
@@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         fake one.
         """
         clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, "localhost")
-        self.assertEqual(port, 6379)
+        while clients:
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+            self.assertEqual(host, "localhost")
+            self.assertEqual(port, 6379)
 
-        client_protocol = client_factory.buildProtocol(None)
-        server_protocol = self._redis_server.buildProtocol(None)
+            client_protocol = client_factory.buildProtocol(None)
+            server_protocol = self._redis_server.buildProtocol(None)
 
-        client_to_server_transport = FakeTransport(
-            server_protocol, self.reactor, client_protocol
-        )
-        client_protocol.makeConnection(client_to_server_transport)
-
-        server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, server_protocol
-        )
-        server_protocol.makeConnection(server_to_client_transport)
+            client_to_server_transport = FakeTransport(
+                server_protocol, self.reactor, client_protocol
+            )
+            client_protocol.makeConnection(client_to_server_transport)
 
-        return client_to_server_transport, server_to_client_transport
+            server_to_client_transport = FakeTransport(
+                client_protocol, self.reactor, server_protocol
+            )
+            server_protocol.makeConnection(server_to_client_transport)
 
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
             (channel,) = args
             self._server.add_subscriber(self)
             self.send(["subscribe", channel, 1])
+
+        # Since we use SET/GET to cache things we can safely no-op them.
+        elif command == b"SET":
+            self.send("OK")
+        elif command == b"GET":
+            self.send(None)
         else:
             raise Exception("Unknown command")
 
@@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
             # We assume bytes are just unicode strings.
             obj = obj.decode("utf-8")
 
+        if obj is None:
+            return "$-1\r\n"
         if isinstance(obj, str):
             return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
         if isinstance(obj, int):
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+    """Test the authentication of HTTP calls between workers."""
+
+    servlets = [register.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        # This isn't a real configuration option but is used to provide the main
+        # homeserver and worker homeserver different options.
+        main_replication_secret = config.pop("main_replication_secret", None)
+        if main_replication_secret:
+            config["worker_replication_secret"] = main_replication_secret
+        return self.setup_test_homeserver(config=config)
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.client_reader"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+
+        return config
+
+    def _test_register(self) -> FakeChannel:
+        """Run the actual test:
+
+        1. Create a worker homeserver.
+        2. Start registration by providing a user/password.
+        3. Complete registration by providing dummy auth (this hits the main synapse).
+        4. Return the final request.
+
+        """
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
+        site = self._hs_to_site[worker_hs]
+
+        channel_1 = make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        return make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
+        )
+
+    def test_no_auth(self):
+        """With no authentication the request should finish.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+    @override_config({"main_replication_secret": "my-secret"})
+    def test_missing_auth(self):
+        """If the main process expects a secret that is not provided, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config(
+        {
+            "main_replication_secret": "my-secret",
+            "worker_replication_secret": "wrong-secret",
+        }
+    )
+    def test_unauthorized(self):
+        """If the main process receives the wrong secret, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config({"worker_replication_secret": "my-secret"})
+    def test_authorized(self):
+        """The request should finish when the worker provides the authentication header.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha import register
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
 
 logger = logging.getLogger(__name__)
 
 
 class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
-    """Base class for tests of the replication streams"""
+    """Test using one or more client readers for registration."""
 
     servlets = [register.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
-        self.recaptcha_checker = DummyRecaptchaChecker(hs)
-        auth_handler = hs.get_auth_handler()
-        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.client_reader"
@@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
         site = self._hs_to_site[worker_hs]
 
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
         site_1 = self._hs_to_site[worker_hs_1]
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site_1,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
         site_2 = self._hs_to_site[worker_hs_2]
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site_2,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {"send_federation": True},
-            http_client=mock_client,
+            federation_http_client=mock_client,
         )
 
         user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..d1feca961f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
-        self.reactor.lookups["example.com"] = "127.0.0.2"
+        self.reactor.lookups["example.com"] = "1.2.3.4"
 
     def default_config(self):
         conf = super().default_config()
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             the media which the caller should respond to.
         """
         resource = hs.get_media_repository_resource().children[b"download"]
-        _, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "https://push.example.com/push"},
+                data={"url": "https://push.example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.pusher",
             {"start_pushers": True},
-            proxied_http_client=http_client_mock,
+            proxied_blacklisted_http_client=http_client_mock,
         )
 
         event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher1",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock1,
+            proxied_blacklisted_http_client=http_client_mock1,
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher2",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock2,
+            proxied_blacklisted_http_client=http_client_mock2,
         )
 
         # We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_not_called()
         self.assertEqual(
             http_client_mock1.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock2.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..8d494ebc03 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Do an initial sync so that we're up to date.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
         )
         next_batch = channel.json_body["next_batch"]
@@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Check that syncing still gets the new event, despite the gap in the
         # stream IDs.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
         first_event_in_room2 = response["event_id"]
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
         self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         # Paginating back in the first room should not produce any results, as
         # no events have happened in it. This tests that we are correctly
         # filtering results based on the vector clock portion.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Paginating back on the second room should produce the first event
         # again. This tests that pagination isn't completely broken.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Paginating forwards should give the same results
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
         self.assertListEqual([], channel.json_body["chunk"])
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 4f76f8f768..9d22c04073 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
         return resource
 
     def test_version_string(self):
-        request, channel = self.make_request("GET", self.url, shorthand=False)
+        channel = self.make_request("GET", self.url, shorthand=False)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -68,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
     def test_delete_group(self):
         # Create a new group
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/create_group".encode("ascii"),
             access_token=self.admin_user_tok,
@@ -84,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         # Invite/join another user
 
         url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         url = "/groups/%s/self/accept_invite" % (group_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -101,7 +99,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
         # Now delete the group
         url = "/_synapse/admin/v1/delete_group/" + group_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             access_token=self.admin_user_tok,
@@ -123,7 +121,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         """
 
         url = "/groups/%s/profile" % (group_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
 
@@ -134,7 +132,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     def _get_groups_user_is_in(self, access_token):
         """Returns the list of groups the user is in (given their access token)
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/joined_groups".encode("ascii"), access_token=access_token
         )
 
@@ -155,9 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-        self.hs = hs
-
         # Allow for uploading and downloading to/from the media repo
         self.media_repo = hs.get_media_repository_resource()
         self.download_resource = self.media_repo.children[b"download"]
@@ -210,13 +205,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
     def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
         """Ensure a piece of media is quarantined when trying to access it."""
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -241,7 +236,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Attempt quarantine media APIs as non-admin
         url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=non_admin_user_tok,
         )
 
@@ -254,7 +249,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # And the roomID/userID endpoint
         url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=non_admin_user_tok,
         )
 
@@ -282,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         server_name, media_id = server_name_and_media_id.split("/")
 
         # Attempt to access the media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -299,7 +294,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             urllib.parse.quote(server_name),
             urllib.parse.quote(media_id),
         )
-        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request("POST", url, access_token=admin_user_tok,)
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
@@ -351,7 +346,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
                 room_id
             )
-        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request("POST", url, access_token=admin_user_tok,)
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
         self.assertEqual(
@@ -395,7 +390,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
             non_admin_user
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=admin_user_tok,
         )
         self.pump(1.0)
@@ -431,13 +426,17 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Mark the second item as safe from quarantine.
         _, media_id_2 = server_and_media_id_2.split("/")
-        self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+        # Quarantine the media
+        url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
+        channel = self.make_request("POST", url, access_token=admin_user_tok)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
         # Quarantine all media by this user
         url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
             non_admin_user
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=admin_user_tok,
         )
         self.pump(1.0)
@@ -453,7 +452,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
 
         # Attempt to access each piece of media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index cf3a007598..248c4442c3 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         Try to get a device of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("PUT", self.url, b"{}")
+        channel = self.make_request("PUT", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("DELETE", self.url, b"{}")
+        channel = self.make_request("DELETE", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -69,21 +69,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url, access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.url, access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", self.url, access_token=self.other_user_token,
         )
 
@@ -99,23 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -129,23 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -158,22 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         # Delete unknown device returns status 200
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -197,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         }
 
         body = json.dumps(update)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
@@ -208,9 +190,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
 
         # Ensure the display name was not updated.
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -227,16 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
-            "PUT", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Ensure the display name was not updated.
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -247,7 +223,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         # Set new display_name
         body = json.dumps({"display_name": "new displayname"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
@@ -257,9 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check new display_name
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -268,9 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         Tests that a normal lookup for a device is successfully
         """
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -291,7 +263,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(1, number_devices)
 
         # Delete device
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", self.url, access_token=self.admin_user_tok,
         )
 
@@ -323,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list devices of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -334,9 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -346,9 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -359,9 +327,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -373,9 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get devices
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -391,9 +355,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
             self.login("user", "pass")
 
         # Get devices
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_devices, channel.json_body["total"])
@@ -431,7 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Try to delete devices of an user without authentication.
         """
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -442,9 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "POST", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("POST", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -454,9 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
-        request, channel = self.make_request(
-            "POST", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -467,9 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
 
-        request, channel = self.make_request(
-            "POST", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -479,7 +435,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a remove of a device that does not exist returns 200.
         """
         body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
@@ -510,7 +466,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
 
         # Delete devices
         body = json.dumps({"devices": device_ids})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 11b72c10f7..d0090faa4f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -74,7 +72,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
         Try to get an event report without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -84,9 +82,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.other_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -96,9 +92,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
@@ -111,7 +105,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with limit
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -126,7 +120,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a defined starting point (from)
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -141,7 +135,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a defined starting point and limit
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -156,7 +150,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of room
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?room_id=%s" % self.room_id1,
             access_token=self.admin_user_tok,
@@ -176,7 +170,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of user
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?user_id=%s" % self.other_user,
             access_token=self.admin_user_tok,
@@ -196,7 +190,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of user and room
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
             access_token=self.admin_user_tok,
@@ -218,7 +212,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         # fetch the most recent first, largest timestamp
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
         )
 
@@ -234,7 +228,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
             report += 1
 
         # fetch the oldest first, smallest timestamp
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
         )
 
@@ -254,7 +248,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a invalid search order returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
         )
 
@@ -267,7 +261,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a negative limit parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -279,7 +273,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a negative from parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -293,7 +287,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -304,7 +298,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -315,7 +309,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -327,7 +321,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         # Check
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -342,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "rooms/%s/report/%s" % (room_id, event_id),
             json.dumps({"score": -100, "reason": "this makes me sad"}),
@@ -375,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -399,7 +391,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         """
         Try to get event report without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -409,9 +401,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.other_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -421,9 +411,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing get a reported event
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self._check_fields(channel.json_body)
@@ -434,7 +422,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         """
 
         # `report_id` is negative
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/-123",
             access_token=self.admin_user_tok,
@@ -448,7 +436,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         )
 
         # `report_id` is a non-numerical string
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/abcdef",
             access_token=self.admin_user_tok,
@@ -462,7 +450,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         )
 
         # `report_id` is undefined
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/",
             access_token=self.admin_user_tok,
@@ -480,7 +468,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing that a not existing `report_id` returns a 404.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/123",
             access_token=self.admin_user_tok,
@@ -496,7 +484,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "rooms/%s/report/%s" % (room_id, event_id),
             json.dumps({"score": -100, "reason": "this makes me sad"}),
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index dadf9db660..51a7731693 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -50,7 +49,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request("DELETE", url, b"{}")
+        channel = self.make_request("DELETE", url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -64,9 +63,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.other_user_token,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -77,9 +74,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -90,9 +85,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -121,7 +114,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertEqual(server_name, self.server_name)
 
         # Attempt to access media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
@@ -146,9 +139,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
 
         # Delete media
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -157,7 +148,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         )
 
         # Attempt to access media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
@@ -189,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -204,7 +194,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         Try to delete media without authentication.
         """
 
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -216,7 +206,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.other_user = self.register_user("user", "pass")
         self.other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, access_token=self.other_user_token,
         )
 
@@ -229,7 +219,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
         )
 
@@ -240,9 +230,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         If the parameter `before_ts` is missing, an error is returned.
         """
-        request, channel = self.make_request(
-            "POST", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -254,7 +242,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         If parameters are invalid, an error is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -265,7 +253,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=1234&size_gt=-1234",
             access_token=self.admin_user_tok,
@@ -278,7 +266,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=1234&keep_profiles=not_bool",
             access_token=self.admin_user_tok,
@@ -308,7 +296,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # timestamp after upload/create
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -332,7 +320,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         self._access_media(server_and_media_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -344,7 +332,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # timestamp after upload
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -367,7 +355,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
             access_token=self.admin_user_tok,
@@ -378,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
             access_token=self.admin_user_tok,
@@ -401,7 +389,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         # set media as avatar
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/avatar_url" % (self.admin_user,),
             content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
@@ -410,7 +398,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
             access_token=self.admin_user_tok,
@@ -421,7 +409,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
             access_token=self.admin_user_tok,
@@ -445,7 +433,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # set media as room avatar
         room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/state/m.room.avatar" % (room_id,),
             content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
@@ -454,7 +442,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
             access_token=self.admin_user_tok,
@@ -465,7 +453,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
             access_token=self.admin_user_tok,
@@ -512,7 +500,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         media_id = server_and_media_id.split("/")[1]
         local_path = self.filepaths.local_media_filepath(media_id)
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46933a0493..7c47aa7e0a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
 from mock import Mock
 
 import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import Codes
 from synapse.rest.client.v1 import directory, events, login, room
 
@@ -79,7 +80,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -103,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Enable world readable
         url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             json.dumps({"history_visibility": "world_readable"}),
@@ -113,7 +114,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -130,7 +131,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         """
 
         url = "rooms/%s/initialSync" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -138,7 +139,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "events?timeout=0&room_id=" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -184,7 +185,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
         )
 
@@ -197,7 +198,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, json.dumps({}), access_token=self.admin_user_tok,
         )
 
@@ -210,7 +211,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/rooms/invalidroom/delete"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, json.dumps({}), access_token=self.admin_user_tok,
         )
 
@@ -225,7 +226,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"new_room_user_id": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -244,7 +245,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"new_room_user_id": "@not:exist.bla"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -262,7 +263,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"block": "NotBool"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -278,7 +279,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"purge": "NotBool"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -304,7 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": True, "purge": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -337,7 +338,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": False, "purge": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -371,7 +372,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": False, "purge": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -418,7 +419,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -448,7 +449,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Enable world readable
         url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             json.dumps({"history_visibility": "world_readable"}),
@@ -465,7 +466,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -530,7 +531,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
 
         url = "rooms/%s/initialSync" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -538,7 +539,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "events?timeout=0&room_id=" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -569,7 +570,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
         self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
 
         url = "/_synapse/admin/v1/purge_room"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             {"room_id": room_id},
@@ -604,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -623,7 +622,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
 
@@ -704,7 +703,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 limit,
                 "name",
             )
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(
@@ -744,7 +743,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(room_ids, returned_room_ids)
 
         url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -764,7 +763,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Create a new alias to this room
         url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             {"room_id": room_id},
@@ -794,7 +793,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -835,7 +834,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url = "/_matrix/client/r0/directory/room/%s" % (
                 urllib.parse.quote(test_alias),
             )
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "PUT",
                 url.encode("ascii"),
                 {"room_id": room_id},
@@ -875,7 +874,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
             if reverse:
                 url += "&dir=b"
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1011,7 +1010,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 expected_http_code: The expected http code for the request
             """
             url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -1050,6 +1049,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
         _search_test(room_id_2, "else")
         _search_test(room_id_2, "se")
 
+        # Test case insensitive
+        _search_test(room_id_1, "SOMETHING")
+        _search_test(room_id_1, "THING")
+
+        _search_test(room_id_2, "ELSE")
+        _search_test(room_id_2, "SE")
+
         _search_test(None, "foo")
         _search_test(None, "bar")
         _search_test(None, "", expected_http_code=400)
@@ -1072,7 +1078,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1084,6 +1090,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertIn("canonical_alias", channel.json_body)
         self.assertIn("joined_members", channel.json_body)
         self.assertIn("joined_local_members", channel.json_body)
+        self.assertIn("joined_local_devices", channel.json_body)
         self.assertIn("version", channel.json_body)
         self.assertIn("creator", channel.json_body)
         self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1103,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(room_id_1, channel.json_body["room_id"])
 
+    def test_single_room_devices(self):
+        """Test that `joined_local_devices` can be requested correctly"""
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+        # Have another user join the room
+        user_1 = self.register_user("foo", "pass")
+        user_tok_1 = self.login("foo", "pass")
+        self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+        # leave room
+        self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+        self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["joined_local_devices"])
+
     def test_room_members(self):
         """Test that room members can be requested correctly"""
         # Create two test rooms
@@ -1119,7 +1159,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id_2, user_3, tok=user_tok_3)
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1130,7 +1170,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 3)
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1140,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.json_body["total"], 3)
 
+    def test_room_state(self):
+        """Test that room state can be requested correctly"""
+        # Create two test rooms
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("state", channel.json_body)
+        # testing that the state events match is painful and not done here. We assume that
+        # the create_room already does the right thing, so no need to verify that we got
+        # the state events it created.
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
@@ -1170,7 +1225,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1186,7 +1241,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"unknown_parameter": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1202,7 +1257,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1218,7 +1273,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": "@not:exist.bla"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1238,7 +1293,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         body = json.dumps({"user_id": self.second_user_id})
         url = "/_synapse/admin/v1/join/!unknown:test"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1255,7 +1310,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         body = json.dumps({"user_id": self.second_user_id})
         url = "/_synapse/admin/v1/join/invalidroom"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1274,7 +1329,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1286,7 +1341,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1303,7 +1358,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1333,7 +1388,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if server admin is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1344,7 +1399,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1355,7 +1410,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1372,7 +1427,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1384,13 +1439,150 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
 
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.creator = self.register_user("creator", "test")
+        self.creator_tok = self.login("creator", "test")
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+
+        self.public_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+        self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+            self.public_room_id
+        )
+
+    def test_public_room(self):
+        """Test that getting admin in a public room works.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room and ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_private_room(self):
+        """Test that getting admin in a private room works and we get invited.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False,
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room (we should have received an
+        # invite) and can ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_other_user(self):
+        """Test that giving admin in a public room works to a non-admin user works.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={"user_id": self.second_user_id},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room and ban a user.
+        self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
+        self.helper.change_membership(
+            room_id,
+            self.second_user_id,
+            "@test:test",
+            Membership.BAN,
+            tok=self.second_tok,
+        )
+
+    def test_not_enough_power(self):
+        """Test that we get a sensible error if there are no local room admins.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        # The creator drops admin rights in the room.
+        pl = self.helper.get_state(
+            room_id, EventTypes.PowerLevels, tok=self.creator_tok
+        )
+        pl["users"][self.creator] = 0
+        self.helper.send_state(
+            room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        # We expect this to fail with a 400 as there are no room admins.
+        #
+        # (Note we assert the error message to ensure that it's not denied for
+        # some other reason)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            channel.json_body["error"],
+            "No local admin user in room with power to update power levels.",
+        )
+
+
 PURGE_TABLES = [
     "current_state_events",
     "event_backward_extremities",
@@ -1419,7 +1611,6 @@ PURGE_TABLES = [
     "event_push_summary",
     "pusher_throttle",
     "group_summary_rooms",
-    "local_invites",
     "room_account_data",
     "room_tags",
     # "state_groups",  # Current impl leaves orphaned state groups around.
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 907b49f889..f48be3d65a 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -46,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         Try to list users without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -55,7 +54,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error 403 is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
         )
 
@@ -67,7 +66,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         If parameters are invalid, an error is returned.
         """
         # unkown order_by
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
         )
 
@@ -75,7 +74,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative from
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -83,7 +82,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative limit
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -91,7 +90,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative from_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -99,7 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative until_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -107,7 +106,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # until_ts smaller from_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?from_ts=10&until_ts=5",
             access_token=self.admin_user_tok,
@@ -117,7 +116,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # empty search term
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
         )
 
@@ -125,7 +124,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
         )
 
@@ -138,7 +137,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(10, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -154,7 +153,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(20, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -170,7 +169,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(20, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -190,7 +189,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -201,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -212,7 +211,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -223,7 +222,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # Set `from` to value of `next_token` for request remaining entries
         # Check `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -238,9 +237,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         if users have no media created
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -316,15 +313,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         ts1 = self.clock.time_msec()
 
         # list all media when filter is not set
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
 
         # filter media starting at `ts1` after creating first media
         # result is 0
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -337,7 +332,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_media(self.other_user_tok, 3)
 
         # filter media between `ts1` and `ts2`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
             access_token=self.admin_user_tok,
@@ -346,7 +341,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
 
         # filter media until `ts2` and earlier
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -356,14 +351,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(20, 1)
 
         # check without filter get all users
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
 
         # filter user 1 and 10-19 by `user_id`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?search_term=foo_user_1",
             access_token=self.admin_user_tok,
@@ -372,7 +365,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 11)
 
         # filter on this user in `displayname`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?search_term=bar_user_10",
             access_token=self.admin_user_tok,
@@ -382,7 +375,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 1)
 
         # filter and get empty result
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -447,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         url = self.url + "?order_by=%s" % (order_type,)
         if dir is not None and dir in ("b", "f"):
             url += "&dir=%s" % (dir,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..ee05ee60bc 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,14 +18,17 @@ import hmac
 import json
 import urllib.parse
 from binascii import unhexlify
+from typing import Optional
 
 from mock import Mock
 
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -70,7 +73,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         self.hs.config.registration_shared_secret = None
 
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
@@ -87,7 +90,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         self.hs.get_secrets = Mock(return_value=secrets)
 
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.json_body, {"nonce": "abcd"})
 
@@ -96,14 +99,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         Calling GET on the endpoint will return a randomised nonce, which will
         only last for SALT_TIMEOUT (60s).
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         # 59 seconds
         self.reactor.advance(59)
 
         body = json.dumps({"nonce": nonce})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("username must be specified", channel.json_body["error"])
@@ -111,7 +114,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         # 61 seconds
         self.reactor.advance(2)
 
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -120,7 +123,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         Only the provided nonce can be used, as it's checked in the MAC.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -136,7 +139,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -146,7 +149,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         When the correct nonce is provided, and the right key is provided, the
         user is registered.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -165,7 +168,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -174,7 +177,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         A valid unrecognised nonce.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -190,13 +193,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -209,7 +212,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
 
         def nonce():
-            request, channel = self.make_request("GET", self.url)
+            channel = self.make_request("GET", self.url)
             return channel.json_body["nonce"]
 
         #
@@ -218,7 +221,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -229,28 +232,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce()})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("username must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
@@ -261,28 +264,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce(), "username": "a"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("password must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Super long
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
@@ -300,7 +303,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "user_type": "invalid",
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -311,7 +314,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
 
         # set no displayname
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -321,17 +324,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         body = json.dumps(
             {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob1:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob1:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob1:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("bob1", channel.json_body["displayname"])
 
         # displayname is None
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -347,17 +350,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob2:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob2:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob2:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("bob2", channel.json_body["displayname"])
 
         # displayname is empty
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -373,16 +376,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob3:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob3:test/displayname")
         self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
 
         # set displayname
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -398,12 +401,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob4:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob4:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Bob's Name", channel.json_body["displayname"])
 
@@ -429,7 +432,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         )
 
         # Register new user with admin API
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -448,7 +451,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -466,23 +469,34 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.register_user("user1", "pass1", admin=False)
-        self.register_user("user2", "pass2", admin=False)
-
     def test_no_auth(self):
         """
         Try to list users without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        self._create_users(1)
+        other_user_token = self.login("user1", "pass1")
+
+        channel = self.make_request("GET", self.url, access_token=other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_all_users(self):
         """
         List all users, including deactivated users.
         """
-        request, channel = self.make_request(
+        self._create_users(2)
+
+        channel = self.make_request(
             "GET",
             self.url + "?deactivated=true",
             b"{}",
@@ -493,6 +507,449 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(3, len(channel.json_body["users"]))
         self.assertEqual(3, channel.json_body["total"])
 
+        # Check that all fields are available
+        self._check_fields(channel.json_body["users"])
+
+    def test_search_term(self):
+        """Test that searching for a users works correctly"""
+
+        def _search_test(
+            expected_user_id: Optional[str],
+            search_term: str,
+            search_field: Optional[str] = "name",
+            expected_http_code: Optional[int] = 200,
+        ):
+            """Search for a user and check that the returned user's id is a match
+
+            Args:
+                expected_user_id: The user_id expected to be returned by the API. Set
+                    to None to expect zero results for the search
+                search_term: The term to search for user names with
+                search_field: Field which is to request: `name` or `user_id`
+                expected_http_code: The expected http code for the request
+            """
+            url = self.url + "?%s=%s" % (search_field, search_term,)
+            channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+            if expected_http_code != 200:
+                return
+
+            # Check that users were returned
+            self.assertTrue("users" in channel.json_body)
+            self._check_fields(channel.json_body["users"])
+            users = channel.json_body["users"]
+
+            # Check that the expected number of users were returned
+            expected_user_count = 1 if expected_user_id else 0
+            self.assertEqual(len(users), expected_user_count)
+            self.assertEqual(channel.json_body["total"], expected_user_count)
+
+            if expected_user_id:
+                # Check that the first returned user id is correct
+                u = users[0]
+                self.assertEqual(expected_user_id, u["name"])
+
+        self._create_users(2)
+
+        user1 = "@user1:test"
+        user2 = "@user2:test"
+
+        # Perform search tests
+        _search_test(user1, "er1")
+        _search_test(user1, "me 1")
+
+        _search_test(user2, "er2")
+        _search_test(user2, "me 2")
+
+        _search_test(user1, "er1", "user_id")
+        _search_test(user2, "er2", "user_id")
+
+        # Test case insensitive
+        _search_test(user1, "ER1")
+        _search_test(user1, "NAME 1")
+
+        _search_test(user2, "ER2")
+        _search_test(user2, "NAME 2")
+
+        _search_test(user1, "ER1", "user_id")
+        _search_test(user2, "ER2", "user_id")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+
+        _search_test(None, "foo", "user_id")
+        _search_test(None, "bar", "user_id")
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid guests
+        channel = self.make_request(
+            "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid deactivated
+        channel = self.make_request(
+            "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of users with limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["users"])
+
+    def test_from(self):
+        """
+        Testing list of users with a defined starting point (from)
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["users"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of users with a defined starting point and limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["users"]), 10)
+        self._check_fields(channel.json_body["users"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
+        for u in content:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def _create_users(self, number_users: int):
+        """
+        Create a number of users
+        Args:
+            number_users: Number of users to be created
+        """
+        for i in range(1, number_users + 1):
+            self.register_user(
+                "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+            )
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass", displayname="User1")
+        self.other_user_token = self.login("user", "pass")
+        self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+            self.other_user
+        )
+        self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+            self.other_user
+        )
+
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to deactivate users without authentication.
+        """
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+        channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST", url, access_token=self.other_user_token, content=b"{}",
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that deactivation for a user that does not exist returns a 404
+        """
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/deactivate/@unknown_person:test",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_erase_is_not_bool(self):
+        """
+        If parameter `erase` is not boolean, return an error
+        """
+        body = json.dumps({"erase": "False"})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that deactivation for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+    def test_deactivate_user_erase_true(self):
+        """
+        Test deactivating an user and set `erase` to `true`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": True})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertIsNone(channel.json_body["avatar_url"])
+        self.assertIsNone(channel.json_body["displayname"])
+
+        self._is_erased("@user:test", True)
+
+    def test_deactivate_user_erase_false(self):
+        """
+        Test deactivating an user and set `erase` to `false`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": False})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        self._is_erased("@user:test", False)
+
+    def _is_erased(self, user_id: str, expect: bool) -> None:
+        """Assert that the user is erased or not
+        """
+        d = self.store.is_user_erased(user_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertFalse(self.get_success(d))
+
 
 class UserRestTestCase(unittest.HomeserverTestCase):
 
@@ -508,7 +965,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.other_user = self.register_user("user", "pass")
+        self.other_user = self.register_user("user", "pass", displayname="User")
         self.other_user_token = self.login("user", "pass")
         self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
             self.other_user
@@ -520,14 +977,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@bob:test"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.other_user_token,
-        )
+        channel = self.make_request("GET", url, access_token=self.other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("You are not a server admin", channel.json_body["error"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, access_token=self.other_user_token, content=b"{}",
         )
 
@@ -539,7 +994,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v2/users/@unknown_person:test",
             access_token=self.admin_user_tok,
@@ -561,11 +1016,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": True,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
-                "avatar_url": None,
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -578,11 +1033,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(True, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -592,6 +1046,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
         """
@@ -606,10 +1061,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": False,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -622,11 +1078,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(False, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -636,6 +1091,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(False, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -651,9 +1107,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Sync to set admin user to active
         # before limit of monthly active users is reached
-        request, channel = self.make_request(
-            "GET", "/sync", access_token=self.admin_user_tok
-        )
+        channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
 
         if channel.code != 200:
             raise HttpResponseException(
@@ -676,7 +1130,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123", "admin": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -715,7 +1169,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123", "admin": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -752,7 +1206,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -769,7 +1223,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual("@bob:test", pushers[0]["user_name"])
+        self.assertEqual("@bob:test", pushers[0].user_name)
 
     @override_config(
         {
@@ -796,7 +1250,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -822,7 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password
         body = json.dumps({"password": "hahaha"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -839,7 +1293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Modify user
         body = json.dumps({"displayname": "foobar"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -851,7 +1305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("foobar", channel.json_body["displayname"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -869,7 +1323,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -882,7 +1336,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -896,10 +1350,30 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test deactivating another user.
         """
 
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
         # Deactivate user
         body = json.dumps({"deactivated": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -909,16 +1383,70 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
         # the user is deactivated, the threepid will be deleted
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
+    @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+    def test_change_name_deactivate_user_user_directory(self):
+        """
+        Test change profile information of a deactivated user and
+        check that it does not appear in user directory
+        """
+
+        # is in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile["display_name"] == "User")
+
+        # Deactivate user
+        body = json.dumps({"deactivated": True})
+
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+
+        # is not in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile is None)
+
+        # Set new displayname user
+        body = json.dumps({"displayname": "Foobar"})
+
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual("Foobar", channel.json_body["displayname"])
+
+        # is not in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile is None)
 
     def test_reactivate_user(self):
         """
@@ -926,7 +1454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Deactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -939,7 +1467,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self._is_erased("@user:test", True)
 
         # Attempt to reactivate the user (without a password).
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -948,7 +1476,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
 
         # Reactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -959,7 +1487,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -976,7 +1504,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Set a user as an admin
         body = json.dumps({"admin": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -988,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -1006,7 +1534,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1018,9 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob", channel.json_body["displayname"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1030,7 +1556,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password (and use a str for deactivate instead of a bool)
         body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1040,9 +1566,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
 
         # Check user is not deactivated
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1070,8 +1594,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1084,7 +1606,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list rooms of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1095,37 +1617,33 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_does_not_exist(self):
         """
-        Tests that a lookup for a user that does not exist returns a 404
+        Tests that a lookup for a user that does not exist returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a 400
+        Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_no_memberships(self):
         """
@@ -1133,9 +1651,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         if user has no memberships
         """
         # Get rooms
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1152,14 +1668,55 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             self.helper.create_room_as(self.other_user, tok=other_user_tok)
 
         # Get rooms
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_rooms, channel.json_body["total"])
         self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
 
+    def test_get_rooms_with_nonlocal_user(self):
+        """
+        Tests that a normal lookup for rooms is successful with a non-local user
+        """
+
+        other_user_tok = self.login("user", "pass")
+        event_builder_factory = self.hs.get_event_builder_factory()
+        event_creation_handler = self.hs.get_event_creation_handler()
+        storage = self.hs.get_storage()
+
+        # Create two rooms, one with a local user only and one with both a local
+        # and remote user.
+        self.helper.create_room_as(self.other_user, tok=other_user_tok)
+        local_and_remote_room_id = self.helper.create_room_as(
+            self.other_user, tok=other_user_tok
+        )
+
+        # Add a remote user to the room.
+        builder = event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@joiner:remote_hs",
+                "state_key": "@joiner:remote_hs",
+                "room_id": local_and_remote_room_id,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event, context = self.get_success(
+            event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(storage.persistence.persist_event(event, context))
+
+        # Now get rooms
+        url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
 
 class PushersRestTestCase(unittest.HomeserverTestCase):
 
@@ -1183,7 +1740,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list pushers of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1194,9 +1751,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1206,9 +1761,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1219,9 +1772,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1232,9 +1783,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get pushers
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1256,14 +1805,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "https://example.com/_matrix/push/v1/notify"},
             )
         )
 
         # Get pushers
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -1287,7 +1834,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1302,7 +1848,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list media of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1313,9 +1859,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1325,9 +1869,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/media"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1338,9 +1880,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1354,7 +1894,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -1373,7 +1913,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -1392,7 +1932,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -1407,7 +1947,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Testing that a negative limit parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -1419,7 +1959,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Testing that a negative from parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -1437,7 +1977,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -1448,7 +1988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -1459,7 +1999,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -1471,7 +2011,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # Check
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -1486,9 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         if user has no media created
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1503,9 +2041,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_media, channel.json_body["total"])
@@ -1571,7 +2107,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         )
 
     def _get_token(self) -> str:
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1580,7 +2116,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_no_auth(self):
         """Try to login as a user without authentication.
         """
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1588,7 +2124,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_not_admin(self):
         """Try to login as a user as a non-admin user.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.other_user_tok
         )
 
@@ -1616,7 +2152,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self._get_token()
 
         # Check that we don't see a new device in our devices list
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1631,25 +2167,19 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout with the puppet token
-        request, channel = self.make_request(
-            "POST", "logout", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should no longer work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens should still work
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1662,25 +2192,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout all with the real user token
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "logout/all", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should still work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens shouldn't
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
@@ -1693,25 +2219,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout all with the admin user token
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "logout/all", b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should no longer work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens should still work
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1778,8 +2300,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1793,11 +2313,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         Try to get information of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url1, b"{}")
+        channel = self.make_request("GET", self.url1, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("GET", self.url2, b"{}")
+        channel = self.make_request("GET", self.url2, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
@@ -1808,15 +2328,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.register_user("user2", "pass")
         other_user2_token = self.login("user2", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
@@ -1827,15 +2343,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
         url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
 
-        request, channel = self.make_request(
-            "GET", url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "GET", url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
@@ -1843,16 +2355,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         The lookup should succeed for an admin.
         """
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
@@ -1863,16 +2371,76 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of an user without authentication.
+        """
+        channel = self.make_request("POST", self.url)
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that shadow-banning for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+
+    def test_success(self):
+        """
+        Shadow-banning should succeed for an admin.
+        """
+        # The user starts off as not shadow-banned.
+        other_user_token = self.login("user", "pass")
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index e2e6a5e16d..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     def test_render_public_consent(self):
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
         )
         self.assertEqual(channel.code, 200)
@@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
             + "&u=user"
         )
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
@@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(consented, "False")
 
         # POST to the consent page, saying we've agreed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "POST",
@@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
 
         # Fetch the consent page, to get the consent version -- it should have
         # changed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index a1ccc4ee9a..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
-        request, channel = self.make_request("GET", url)
+        channel = self.make_request("GET", url)
 
         self.assertEqual(channel.code, expected_code, channel.result)
 
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 259c6a1985..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,9 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
-        request, channel = self.make_request(
-            b"POST", "/createRoom", b"{}", access_token=tok
-        )
+        channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
         room_id = channel.json_body["room_id"]
 
@@ -56,7 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", request_url, request_data, access_token=tok
         )
         self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index c1f516cc93..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
         """
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
 
-        request, channel = self.make_request(
-            "POST", path, content={}, access_token=access_token
-        )
+        channel = self.make_request("POST", path, content={}, access_token=access_token)
         self.assertEqual(int(channel.result["code"]), expect_code)
         return channel.json_body
 
     def _sync_room_timeline(self, access_token, room_id):
-        request, channel = self.make_request(
-            "GET", "sync", access_token=self.mod_access_token
-        )
+        channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
         self.assertEqual(channel.result["code"], b"200")
         room_sync = channel.json_body["rooms"]["join"][room_id]
         return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f56b5d9231..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,7 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
-        request, channel = self.make_request("GET", url, access_token=self.token)
+        channel = self.make_request("GET", url, access_token=self.token)
 
         self.assertEqual(channel.code, expected_code, channel.result)
 
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 94dcfb9f7c..0ebdf1415b 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
 
 from tests import unittest
 
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
         self.store = self.hs.get_datastore()
 
         self.get_success(
-            self.store.db_pool.simple_update(
-                table="users",
-                keyvalues={"name": self.banned_user_id},
-                updatevalues={"shadow_banned": True},
-                desc="shadow_ban",
-            )
+            self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
         )
 
         self.other_user_id = self.register_user("otheruser", "pass")
@@ -89,7 +85,7 @@ class RoomTestCase(_ShadowBannedBase):
         )
 
         # Inviting the user completes successfully.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/invite" % (room_id,),
             {"id_server": "test", "medium": "email", "address": "test@test.test"},
@@ -103,7 +99,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_create_room(self):
         """Invitations during a room creation should be discarded, but the room still gets created."""
         # The room creation is successful.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/createRoom",
             {"visibility": "public", "invite": [self.other_user_id]},
@@ -158,7 +154,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
             {"new_version": "6"},
@@ -183,7 +179,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
             {"typing": True, "timeout": 30000},
@@ -198,7 +194,7 @@ class RoomTestCase(_ShadowBannedBase):
         # The other user can join and send typing events.
         self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
             {"typing": True, "timeout": 30000},
@@ -244,7 +240,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
 
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
             {"displayname": new_display_name},
@@ -254,7 +250,7 @@ class ProfileTestCase(_ShadowBannedBase):
         self.assertEqual(channel.json_body, {})
 
         # The user's display name should be updated.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/profile/%s/displayname" % (self.banned_user_id,)
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -282,7 +278,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
 
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
             % (room_id, self.banned_user_id),
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0e96697f9b..227fffab58 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         callback = Mock(spec=[], side_effect=check)
         current_rules_module().check_event_allowed = callback
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
             {},
@@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ev.type, k[0])
             self.assertEqual(ev.state_key, k[1])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
             {},
@@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
 
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
 
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         event_id = channel.json_body["event_id"]
 
         # ... and check that it got modified
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 7a2c653df8..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -91,7 +91,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         # that we can make sure that the check is done on the whole alias.
         data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, 400, channel.result)
@@ -104,7 +104,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         # as cautious as possible here.
         data = {"room_alias_name": random_string(5)}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         data = {"aliases": [self.random_alias(alias_length)]}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 12a93f5687..0a5ca317ea 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         # implementation is now part of the r0 implementation, the newer
         # behaviour is used instead to be consistent with the r0 spec.
         # see issue #2602
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s" % ("invalid" + self.token,)
         )
         self.assertEquals(channel.code, 401, msg=channel.result)
 
         # valid token, expect content
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
@@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         )
 
         # valid token, expect content
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
@@ -149,7 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(self.room_id, tok=self.token)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events/" + event_id, access_token=self.token,
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 176ddf7ec9..bfcb786af8 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,23 +1,83 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import urllib.parse
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
 
 from mock import Mock
 
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import create_requester
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+      <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
 
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
 
+# a (valid) url with some annoying characters in.  %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -63,7 +123,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
                 "password": "monkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -82,7 +142,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -108,7 +168,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "monkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -127,7 +187,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -153,7 +213,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "notamonkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -172,7 +232,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "notamonkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
@@ -181,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.register_user("kermit", "monkey")
 
         # we shouldn't be able to make requests without an access token
-        request, channel = self.make_request(b"GET", TEST_URL)
+        channel = self.make_request(b"GET", TEST_URL)
         self.assertEquals(channel.result["code"], b"401", channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
 
@@ -191,25 +251,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.code, 200, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
@@ -223,9 +279,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
@@ -233,16 +287,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # ... but if we delete that device, it will be a proper logout
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], False)
 
     def _delete_device(self, access_token, user_id, password, device_id):
         """Perform the UI-Auth to delete a device"""
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
         self.assertEquals(channel.code, 401, channel.result)
@@ -262,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "session": channel.json_body["session"],
         }
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE",
             "devices/" + device_id,
             access_token=access_token,
@@ -278,26 +330,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
 
         # Now try to hard logout this session
-        request, channel = self.make_request(
-            b"POST", "/logout", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     @override_config({"session_lifetime": "24h"})
@@ -308,29 +354,313 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
 
         # Now try to hard log out all of the user's sessions
-        request, channel = self.make_request(
-            b"POST", "/logout/all", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+    """Tests for homeservers with multiple SSO providers enabled"""
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        config["public_baseurl"] = BASE_URL
+
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+            "service_url": "https://matrix.goodserver.com:8448",
+        }
+
+        config["saml2_config"] = {
+            "sp_config": {
+                "metadata": {"inline": [TEST_SAML_METADATA]},
+                # use the XMLSecurity backend to avoid relying on xmlsec1
+                "crypto_backend": "XMLSecurity",
+            },
+        }
+
+        # default OIDC provider
+        config["oidc_config"] = TEST_OIDC_CONFIG
+
+        # additional OIDC providers
+        config["oidc_providers"] = [
+            {
+                "idp_id": "idp1",
+                "idp_name": "IDP1",
+                "discover": False,
+                "issuer": "https://issuer1",
+                "client_id": "test-client-id",
+                "client_secret": "test-client-secret",
+                "scopes": ["profile"],
+                "authorization_endpoint": "https://issuer1/auth",
+                "token_endpoint": "https://issuer1/token",
+                "userinfo_endpoint": "https://issuer1/userinfo",
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.sub }}"}
+                },
+            }
+        ]
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d.update(build_synapse_client_resource_tree(self.hs))
+        return d
+
+    def test_get_login_flows(self):
+        """GET /login should return password and SSO flows"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.sso"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_get_msc2858_login_flows(self):
+        """The SSO flow should include IdP info if MSC2858 is enabled"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # stick the flows results in a dict by type
+        flow_results = {}  # type: Dict[str, Any]
+        for f in channel.json_body["flows"]:
+            flow_type = f["type"]
+            self.assertNotIn(
+                flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+            )
+            flow_results[flow_type] = f
+
+        self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+        sso_flow = flow_results.pop("m.login.sso")
+        # we should have a set of IdPs
+        self.assertCountEqual(
+            sso_flow["org.matrix.msc2858.identity_providers"],
+            [
+                {"id": "cas", "name": "CAS"},
+                {"id": "saml", "name": "SAML"},
+                {"id": "oidc-idp1", "name": "IDP1"},
+                {"id": "oidc", "name": "OIDC"},
+            ],
+        )
+
+        # the rest of the flows are simple
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(flow_results.values(), expected_flows)
+
+    def test_multi_sso_redirect(self):
+        """/login/sso/redirect should redirect to an identity picker"""
+        # first hit the redirect url, which should redirect to our idp picker
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        uri = channel.headers.getRawHeaders("Location")[0]
+
+        # hitting that picker should give us some HTML
+        channel = self.make_request("GET", uri)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # parse the form to check it has fields assumed elsewhere in this class
+        p = TestHtmlParser()
+        p.feed(channel.result["body"].decode("utf-8"))
+        p.close()
+
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+
+        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+
+    def test_multi_sso_redirect_to_cas(self):
+        """If CAS is chosen, should redirect to the CAS server"""
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=cas",
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+        # it should redirect us to the login page of the cas server
+        self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+        # check that the redirectUrl is correctly encoded in the service param - ie, the
+        # place that CAS will redirect to
+        cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+        service_uri = cas_uri_params["service"][0]
+        _, service_uri_query = service_uri.split("?", 1)
+        service_uri_params = urllib.parse.parse_qs(service_uri_query)
+        self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
+
+    def test_multi_sso_redirect_to_saml(self):
+        """If SAML is chosen, should redirect to the SAML server"""
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=saml",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+        # it should redirect us to the login page of the SAML server
+        self.assertEqual(saml_uri_path, SAML_SERVER)
+
+        # the RelayState is used to carry the client redirect url
+        saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+        relay_state_param = saml_uri_params["RelayState"][0]
+        self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+
+    def test_login_via_oidc(self):
+        """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+
+        # pick the default OIDC provider
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=oidc",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+        # ... and should have set a cookie including the redirect url
+        cookies = dict(
+            h.split(";")[0].split("=", maxsplit=1)
+            for h in channel.headers.getRawHeaders("Set-Cookie")
+        )
+
+        oidc_session_cookie = cookies["oidc_session"]
+        macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+        self.assertEqual(
+            self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+            TEST_CLIENT_REDIRECT_URL,
+        )
+
+        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertTrue(
+            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+        )
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+
+        # ... which should contain our redirect link
+        self.assertEqual(len(p.links), 1)
+        path, query = p.links[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        login_token = params[2][1]
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
+    def test_multi_sso_redirect_to_unknown(self):
+        """An unknown IdP should cause a 400"""
+        channel = self.make_request(
+            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_unknown(self):
+        """If the client tries to pick an unknown IdP, return a 404"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_oidc(self):
+        """If the client pick a known IdP, redirect to it"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
 class CASTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -342,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
 
         config = self.default_config()
+        config["public_baseurl"] = (
+            config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+        )
         config["cas_config"] = {
             "enabled": True,
-            "server_url": "https://fake.test",
-            "service_url": "https://matrix.goodserver.com:8448",
+            "server_url": CAS_SERVER,
         }
 
         cas_user_id = "username"
@@ -402,10 +734,10 @@ class CASTestCase(unittest.HomeserverTestCase):
         cas_ticket_url = urllib.parse.urlunparse(url_parts)
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -430,8 +762,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_cas_redirect_whitelisted(self):
-        """Tests that the SSO login flow serves a redirect to a whitelisted url
-        """
+        """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
@@ -446,7 +777,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
@@ -462,7 +793,9 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         # Deactivate the account.
         self.get_success(
-            self.deactivate_account_handler.deactivate_account(self.user_id, False)
+            self.deactivate_account_handler.deactivate_account(
+                self.user_id, False, create_requester(self.user_id)
+            )
         )
 
         # Request the CAS ticket.
@@ -472,13 +805,14 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         # Because the user is deactivated they are served an error template.
         self.assertEqual(channel.code, 403)
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -495,14 +829,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
-    def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
+        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
+        if isinstance(result, bytes):
+            return result.decode("ascii")
+        return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
     def test_login_jwt_valid_registered(self):
@@ -633,8 +971,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         )
 
     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt"}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -643,6 +981,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -700,14 +1039,16 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs
 
-    def jwt_encode(self, token, secret=jwt_privatekey):
-        return jwt.encode(token, secret, "RS256").decode("ascii")
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
+        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
+        if isinstance(result, bytes):
+            return result.decode("ascii")
+        return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
     def test_login_jwt_valid(self):
@@ -735,7 +1076,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
     ]
 
     def register_as_user(self, username):
-        request, channel = self.make_request(
+        self.make_request(
             b"POST",
             "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
             {"username": username},
@@ -776,60 +1117,56 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_login_appservice_user(self):
-        """Test that an appservice user can use /login
-        """
+        """Test that an appservice user can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_user_bot(self):
-        """Test that the appservice bot can use /login
-        """
+        """Test that the appservice bot can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": self.service.sender},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_wrong_user(self):
-        """Test that non-as users cannot login with the as token
-        """
+        """Test that non-as users cannot login with the as token"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
     def test_login_appservice_wrong_as(self):
-        """Test that as users cannot login with wrong as token
-        """
+        """Test that as users cannot login with wrong as token"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
 
@@ -837,7 +1174,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_login_appservice_no_token(self):
         """Test that users must provide a token when using the appservice
-           login method
+        login method
         """
         self.register_as_user(AS_USER)
 
@@ -845,6 +1182,116 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+    """Tests for the username picker flow of SSO login"""
+
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+
+        config["oidc_config"] = {}
+        config["oidc_config"].update(TEST_OIDC_CONFIG)
+        config["oidc_config"]["user_mapping_provider"] = {
+            "config": {"display_name_template": "{{ user.displayname }}"}
+        }
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://x"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d.update(build_synapse_client_resource_tree(self.hs))
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+
+        # do the start of the login flow
+        channel = self.helper.auth_via_oidc(
+            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        )
+
+        # that should redirect to the username picker
+        self.assertEqual(channel.code, 302, channel.result)
+        picker_url = channel.headers.getRawHeaders("Location")[0]
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
+
+        # ... with a username_mapping_session cookie
+        cookies = {}  # type: Dict[str,str]
+        channel.extract_cookies(cookies)
+        self.assertIn("username_mapping_session", cookies)
+        session_id = cookies["username_mapping_session"]
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # to the completion page
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=picker_url,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # fish the login token out of the returned redirect uri
+        login_token = params[2][1]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5d5c24d01c..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             "red",
-            http_client=None,
+            federation_http_client=None,
             federation_client=Mock(),
             presence_handler=presence_handler,
         )
@@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = True
 
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
@@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = False
 
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 383a9eafac..e59fa70baa 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             self.addCleanup,
             "test",
-            http_client=None,
+            federation_http_client=None,
             resource_for_client=self.mock_resource,
             federation=Mock(),
             federation_client=Mock(),
@@ -189,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.owner_tok = self.login("owner", "pass")
 
     def test_set_displayname(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
             content=json.dumps({"displayname": "test"}),
@@ -202,7 +202,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
     def test_set_displayname_too_long(self):
         """Attempts to set a stupid displayname should get a 400"""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
             content=json.dumps({"displayname": "test" * 100}),
@@ -214,9 +214,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res, "owner")
 
     def get_displayname(self):
-        request, channel = self.make_request(
-            "GET", "/profile/%s/displayname" % (self.owner,)
-        )
+        channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["displayname"]
 
@@ -278,7 +276,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
         )
 
     def request_profile(self, expected_code, url_suffix="", access_token=None):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.profile_url + url_suffix, access_token=access_token
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -320,19 +318,19 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
         """Tests that a user can lookup their own profile without having to be in a room
         if 'require_auth_for_profile_requests' is set to true in the server's config.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/profile/" + self.requester, access_token=self.requester_tok
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/profile/" + self.requester + "/displayname",
             access_token=self.requester_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/profile/" + self.requester + "/avatar_url",
             access_token=self.requester_tok,
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 7add5523c8..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # disable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": False},
@@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule disabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], False)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # disable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": False},
@@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule disabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], False)
 
         # re-enable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": True},
@@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule enabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # check 404 for deleted rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": True},
@@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/.m.muahahah/enabled",
             {"enabled": True},
@@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET actions for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/actions", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # change the rule actions
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/actions",
             {"actions": ["dont_notify"]},
@@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # GET actions for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/actions", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # check 404 for deleted rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/actions",
             {"actions": ["dont_notify"]},
@@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/.m.muahahah/actions",
             {"actions": ["dont_notify"]},
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 49f1073c88..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,9 +26,10 @@ from mock import Mock
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -45,7 +46,7 @@ class RoomBase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         self.hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.hs.get_federation_handler = Mock()
@@ -83,13 +84,13 @@ class RoomPermissionsTestCase(RoomBase):
         self.created_rmid_msg_path = (
             "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
         ).encode("ascii")
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
         )
         self.assertEquals(200, channel.code, channel.result)
 
         # set topic for public room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
             b'{"topic":"Public Room Topic"}',
@@ -111,7 +112,7 @@ class RoomPermissionsTestCase(RoomBase):
             )
 
         # send message in uncreated room, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
             msg_content,
@@ -119,24 +120,24 @@ class RoomPermissionsTestCase(RoomBase):
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room not joined (no state), expect 403
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and joined, expect 200
         self.helper.join(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # send message in created room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_topic_perms(self):
@@ -144,30 +145,30 @@ class RoomPermissionsTestCase(RoomBase):
         topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
 
         # set/get topic in uncreated room, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room not joined, expect 403
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set topic in created PRIVATE room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # get topic in created PRIVATE room and invited, expect 403
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room and joined, expect 200
@@ -175,29 +176,29 @@ class RoomPermissionsTestCase(RoomBase):
 
         # Only room ops can set topic by default
         self.helper.auth_user_id = self.rmcreator_id
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.helper.auth_user_id = self.user_id
 
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
 
         # set/get topic in created PRIVATE room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # get topic in PUBLIC room, not joined, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set topic in PUBLIC room, not joined, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
             topic_content,
@@ -207,7 +208,7 @@ class RoomPermissionsTestCase(RoomBase):
     def _test_get_membership(self, room=None, members=[], expect_code=None):
         for member in members:
             path = "/rooms/%s/state/m.room.member/%s" % (room, member)
-            request, channel = self.make_request("GET", path)
+            channel = self.make_request("GET", path)
             self.assertEquals(expect_code, channel.code)
 
     def test_membership_basic_room_perms(self):
@@ -379,16 +380,16 @@ class RoomsMemberListTestCase(RoomBase):
 
     def test_get_member_list(self):
         room_id = self.helper.create_room_as(self.user_id)
-        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+        channel = self.make_request("GET", "/rooms/%s/members" % room_id)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_room(self):
-        request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
+        channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission(self):
         room_id = self.helper.create_room_as("@some_other_guy:red")
-        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+        channel = self.make_request("GET", "/rooms/%s/members" % room_id)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_mixed_memberships(self):
@@ -397,17 +398,17 @@ class RoomsMemberListTestCase(RoomBase):
         room_path = "/rooms/%s/members" % room_id
         self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
         # can't see list if you're just invited.
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         self.helper.join(room=room_id, user=self.user_id)
         # can see list now joined
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         self.helper.leave(room=room_id, user=self.user_id)
         # can see old list once left
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
 
@@ -418,30 +419,26 @@ class RoomsCreateTestCase(RoomBase):
 
     def test_post_room_no_keys(self):
         # POST with no config keys, expect new room id
-        request, channel = self.make_request("POST", "/createRoom", "{}")
+        channel = self.make_request("POST", "/createRoom", "{}")
 
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_visibility_key(self):
         # POST with visibility config key, expect new room id
-        request, channel = self.make_request(
-            "POST", "/createRoom", b'{"visibility":"private"}'
-        )
+        channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
         self.assertEquals(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_custom_key(self):
         # POST with custom config keys, expect new room id
-        request, channel = self.make_request(
-            "POST", "/createRoom", b'{"custom":"stuff"}'
-        )
+        channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
         self.assertEquals(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_known_and_unknown_keys(self):
         # POST with custom + known config keys, expect new room id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
         )
         self.assertEquals(200, channel.code)
@@ -449,16 +446,16 @@ class RoomsCreateTestCase(RoomBase):
 
     def test_post_room_invalid_content(self):
         # POST with invalid content / paths, expect 400
-        request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
+        channel = self.make_request("POST", "/createRoom", b'{"visibili')
         self.assertEquals(400, channel.code)
 
-        request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
+        channel = self.make_request("POST", "/createRoom", b'["hello"]')
         self.assertEquals(400, channel.code)
 
     def test_post_room_invitees_invalid_mxid(self):
         # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
         # Note the trailing space in the MXID here!
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
         )
         self.assertEquals(400, channel.code)
@@ -476,54 +473,54 @@ class RoomTopicTestCase(RoomBase):
 
     def test_invalid_puts(self):
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", self.path, "{}")
+        channel = self.make_request("PUT", self.path, "{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
+        channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, '{"nao')
+        channel = self.make_request("PUT", self.path, '{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
         )
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, "text only")
+        channel = self.make_request("PUT", self.path, "text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, "")
+        channel = self.make_request("PUT", self.path, "")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # valid key, wrong type
         content = '{"topic":["Topic name"]}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_topic(self):
         # nothing should be there
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(404, channel.code, msg=channel.result["body"])
 
         # valid put
         content = '{"topic":"Topic name"}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # valid get
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
     def test_rooms_topic_with_extra_keys(self):
         # valid put with extra keys
         content = '{"topic":"Seasons","subtopic":"Summer"}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # valid get
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
@@ -539,24 +536,22 @@ class RoomMemberStateTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, "{}")
+        channel = self.make_request("PUT", path, "{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
+        channel = self.make_request("PUT", path, '{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, '{"nao')
+        channel = self.make_request("PUT", path, '{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
-            "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
-        )
+        channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, "text only")
+        channel = self.make_request("PUT", path, "text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, "")
+        channel = self.make_request("PUT", path, "")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # valid keys, wrong types
@@ -565,7 +560,7 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.JOIN,
             Membership.LEAVE,
         )
-        request, channel = self.make_request("PUT", path, content.encode("ascii"))
+        channel = self.make_request("PUT", path, content.encode("ascii"))
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_members_self(self):
@@ -576,10 +571,10 @@ class RoomMemberStateTestCase(RoomBase):
 
         # valid join message (NOOP since we made the room)
         content = '{"membership":"%s"}' % Membership.JOIN
-        request, channel = self.make_request("PUT", path, content.encode("ascii"))
+        channel = self.make_request("PUT", path, content.encode("ascii"))
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         expected_response = {"membership": Membership.JOIN}
@@ -594,10 +589,10 @@ class RoomMemberStateTestCase(RoomBase):
 
         # valid invite message
         content = '{"membership":"%s"}' % Membership.INVITE
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assertEquals(json.loads(content), channel.json_body)
 
@@ -613,18 +608,54 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.INVITE,
             "Join us!",
         )
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomInviteRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        admin.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_rooms_ratelimit(self):
+        """Tests that invites in a room are actually rate-limited."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        for i in range(3):
+            self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+        self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_users_ratelimit(self):
+        """Tests that invites to a specific user are actually rate-limited."""
+
+        for i in range(3):
+            room_id = self.helper.create_room_as(self.user_id)
+            self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
 
     servlets = [
+        admin.register_servlets,
         profile.register_servlets,
         room.register_servlets,
     ]
@@ -666,7 +697,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
 
         # Update the display name for the user.
         path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
-        request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+        channel = self.make_request("PUT", path, {"displayname": "John Doe"})
         self.assertEquals(channel.code, 200, channel.json_body)
 
         # Check that all the rooms have been sent a profile update into.
@@ -676,7 +707,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
                 self.user_id,
             )
 
-            request, channel = self.make_request("GET", path)
+            channel = self.make_request("GET", path)
             self.assertEquals(channel.code, 200)
 
             self.assertIn("displayname", channel.json_body)
@@ -700,9 +731,23 @@ class RoomJoinRatelimitTestCase(RoomBase):
             # Make sure we send more requests than the rate-limiting config would allow
             # if all of these requests ended up joining the user to a room.
             for i in range(4):
-                request, channel = self.make_request("POST", path % room_id, {})
+                channel = self.make_request("POST", path % room_id, {})
                 self.assertEquals(channel.code, 200)
 
+    @unittest.override_config(
+        {
+            "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+            "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+            "autocreate_auto_join_rooms": True,
+        },
+    )
+    def test_autojoin_rooms(self):
+        user_id = self.register_user("testuser", "password")
+
+        # Check that the new user successfully joined the four rooms
+        rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+        self.assertEqual(len(rooms), 4)
+
 
 class RoomMessagesTestCase(RoomBase):
     """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -715,42 +760,40 @@ class RoomMessagesTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, b"{}")
+        channel = self.make_request("PUT", path, b"{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
+        channel = self.make_request("PUT", path, b'{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'{"nao')
+        channel = self.make_request("PUT", path, b'{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
-            "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
-        )
+        channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b"text only")
+        channel = self.make_request("PUT", path, b"text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b"")
+        channel = self.make_request("PUT", path, b"")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_messages_sent(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
 
         content = b'{"body":"test","msgtype":{"type":"a"}}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # custom message types
         content = b'{"body":"test","msgtype":"test.custom.text"}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # m.text message type
         path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
         content = b'{"body":"test2","msgtype":"m.text"}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
 
@@ -764,9 +807,7 @@ class RoomInitialSyncTestCase(RoomBase):
         self.room_id = self.helper.create_room_as(self.user_id)
 
     def test_initial_sync(self):
-        request, channel = self.make_request(
-            "GET", "/rooms/%s/initialSync" % self.room_id
-        )
+        channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
         self.assertEquals(200, channel.code)
 
         self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -807,7 +848,7 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_topo_token_is_accepted(self):
         token = "t1-0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
         self.assertEquals(200, channel.code)
@@ -818,7 +859,7 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
         self.assertEquals(200, channel.code)
@@ -851,7 +892,7 @@ class RoomMessageListTestCase(RoomBase):
         self.helper.send(self.room_id, "message 3")
 
         # Check that we get the first and second message when querying /messages.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -879,7 +920,7 @@ class RoomMessageListTestCase(RoomBase):
 
         # Check that we only get the second message through /message now that the first
         # has been purged.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -896,7 +937,7 @@ class RoomMessageListTestCase(RoomBase):
         # Check that we get no event, but also no error, when querying /messages with
         # the token that was pointing at the first event, because we don't have it
         # anymore.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -955,7 +996,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
         self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
         self.helper.send(self.room, body="There!", tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/search?access_token=%s" % (self.access_token,),
             {
@@ -984,7 +1025,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
         self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
         self.helper.send(self.room, body="There!", tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/search?access_token=%s" % (self.access_token,),
             {
@@ -1032,14 +1073,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_restricted_no_auth(self):
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         self.assertEqual(channel.code, 401, channel.result)
 
     def test_restricted_auth(self):
         self.register_user("user", "pass")
         tok = self.login("user", "pass")
 
-        request, channel = self.make_request("GET", self.url, access_token=tok)
+        channel = self.make_request("GET", self.url, access_token=tok)
         self.assertEqual(channel.code, 200, channel.result)
 
 
@@ -1067,7 +1108,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
         self.displayname = "test user"
         data = {"displayname": self.displayname}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
             request_data,
@@ -1080,7 +1121,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
     def test_per_room_profile_forbidden(self):
         data = {"membership": "join", "displayname": "other test user"}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
             % (self.room_id, self.user_id),
@@ -1090,7 +1131,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
@@ -1123,7 +1164,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_join_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
             content={"reason": reason},
@@ -1137,7 +1178,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
             content={"reason": reason},
@@ -1151,7 +1192,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1165,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1177,7 +1218,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_unban_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1189,7 +1230,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_invite_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1208,7 +1249,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         )
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
             content={"reason": reason},
@@ -1219,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self._check_for_reason(reason)
 
     def _check_for_reason(self, reason):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
                 self.room_id, self.second_user_id
@@ -1268,7 +1309,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """Test that we can filter by a label on a /context request."""
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
@@ -1298,7 +1339,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """Test that we can filter by the absence of a label on a /context request."""
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1333,7 +1374,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
@@ -1361,7 +1402,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
@@ -1378,7 +1419,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1401,7 +1442,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (
@@ -1432,7 +1473,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1467,7 +1508,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1514,7 +1555,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1635,7 +1676,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         # Check that we can still see the messages before the erasure request.
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
             % (self.room_id, event_id),
@@ -1681,7 +1722,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         deactivate_account_handler = self.hs.get_deactivate_account_handler()
         self.get_success(
-            deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+            deactivate_account_handler.deactivate_account(
+                self.user_id, True, create_requester(self.user_id)
+            )
         )
 
         # Invite another user in the room. This is needed because messages will be
@@ -1699,7 +1742,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
         # Check that a user that joined the room after the erasure request can't see
         # the messages anymore.
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
             % (self.room_id, event_id),
@@ -1789,7 +1832,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
 
     def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
             % (self.room_id,),
@@ -1810,7 +1853,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.room_owner_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -1840,14 +1883,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.room_owner_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
 
     def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
             access_token=self.room_owner_tok,
@@ -1859,7 +1902,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 
     def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
             json.dumps(content),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index bbd30f594b..38c51525a3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
@@ -94,7 +94,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user="@jim:red")
 
     def test_set_typing(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
@@ -117,7 +117,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         )
 
     def test_set_not_typing(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": false}',
@@ -125,7 +125,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
     def test_typing_timeout(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
@@ -138,7 +138,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,12 @@
 # limitations under the License.
 
 import json
+import re
 import time
-from typing import Any, Dict, Optional
+import urllib.parse
+from typing import Any, Dict, Mapping, MutableMapping, Optional
+
+from mock import patch
 
 import attr
 
@@ -26,8 +30,11 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.types import JsonDict
 
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
+from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -75,7 +82,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "POST",
@@ -151,7 +158,7 @@ class RestHelper:
         data = {"membership": membership}
         data.update(extra_data)
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -186,7 +193,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -242,9 +249,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        _, channel = make_request(
-            self.hs.get_reactor(), self.site, method, path, content
-        )
+        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -327,7 +332,7 @@ class RestHelper:
         """
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             FakeSite(resource),
             "POST",
@@ -344,3 +349,246 @@ class RestHelper:
         )
 
         return channel.json_body
+
+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+        # expect a confirmation page
+        assert channel.code == 200, channel.result
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+    def auth_via_oidc(
+        self,
+        user_info_dict: JsonDict,
+        client_redirect_url: Optional[str] = None,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Perform an OIDC authentication flow via a mock OIDC provider.
+
+        This can be used for either login or user-interactive auth.
+
+        Starts by making a request to the relevant synapse redirect endpoint, which is
+        expected to serve a 302 to the OIDC provider. We then make a request to the
+        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+        to the OIDC provider.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+
+        Args:
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+            client_redirect_url: for a login flow, the client redirect URL to pass to
+                the login redirect endpoint
+            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+                of the UI auth.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+            Note that the response code may be a 200, 302 or 400 depending on how things
+            went.
+        """
+
+        cookies = {}
+
+        # if we're doing a ui auth, hit the ui auth redirect endpoint
+        if ui_auth_session_id:
+            # can't set the client redirect url for UI Auth
+            assert client_redirect_url is None
+            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+        else:
+            # otherwise, hit the login redirect endpoint
+            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+        # we now have a URI for the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, as well as the cookie
+        # that synapse passes to the client.
+
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
+        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+            "unexpected SSO URI " + oauth_uri_path
+        )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+               from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+               sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
+        params = urllib.parse.parse_qs(oauth_uri_qs)
+        callback_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                callback_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+        return channel
+
+    def initiate_sso_login(
+        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the login-via-sso redirect endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the login
+        servlet to be mounted.
+
+        Args:
+            client_redirect_url: the client redirect URL to pass to the login redirect
+                endpoint
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets redirected to (ie, the SSO server)
+        """
+        params = {}
+        if client_redirect_url:
+            params["redirectUrl"] = client_redirect_url
+
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+        )
+
+        assert channel.code == 302
+        channel.extract_cookies(cookies)
+        return channel.headers.getRawHeaders("Location")[0]
+
+    def initiate_sso_ui_auth(
+        self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the
+        AuthRestServlet to be mounted.
+
+        Args:
+            ui_auth_session_id: the session id of the UI auth
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets linked to (ie, the SSO server)
+        """
+        sso_redirect_endpoint = (
+            "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+            + urllib.parse.urlencode({"session": ui_auth_session_id})
+        )
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+        )
+        # that should serve a confirmation page
+        assert channel.code == 200, channel.text_body
+        channel.extract_cookies(cookies)
+
+        # parse the confirmation page to fish out the link.
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+        assert len(p.links) == 1, "not exactly one link in confirmation page"
+        oauth_uri = p.links[0]
+        return oauth_uri
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
+    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 2ac1ecb7d3..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,13 +19,12 @@ import os
 import re
 from email.parser import Parser
 from typing import Optional
-from urllib.parse import urlencode
 
 import pkg_resources
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -113,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_email(self):
+        """Test that we ratelimit /requestToken for the same email.
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test1@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        def reset(ip):
+            client_secret = "foobar"
+            session_id = self._request_token(email, client_secret, ip)
+
+            self.assertEquals(len(self.email_attempts), 1)
+            link = self._get_link_from_email()
+
+            self._validate_token(link)
+
+            self._reset_password(new_password, session_id, client_secret)
+
+            self.email_attempts.clear()
+
+        # We expect to be able to make three requests before getting rate
+        # limited.
+        #
+        # We change IPs to ensure that we're not being ratelimited due to the
+        # same IP
+        reset("127.0.0.1")
+        reset("127.0.0.2")
+        reset("127.0.0.3")
+
+        with self.assertRaises(HttpResponseException) as cm:
+            reset("127.0.0.4")
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_basic_password_reset_canonicalise_email(self):
         """Test basic password reset flow
         Request password reset with different spelling
@@ -240,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret):
-        request, channel = self.make_request(
+    def _request_token(self, email, client_secret, ip="127.0.0.1"):
+        channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            client_ip=ip,
         )
-        self.assertEquals(200, channel.code, channel.result)
+
+        if channel.code != 200:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body["sid"]
 
@@ -255,7 +309,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         path = link.replace("https://example.com", "")
 
         # Load the password reset confirmation page
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.submit_token_resource),
             "GET",
@@ -268,20 +322,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Now POST to the same endpoint, mimicking the same behaviour as clicking the
         # password reset confirm button
 
-        # Send arguments as url-encoded form data, matching the template's behaviour
-        form_args = []
-        for key, value_list in request.args.items():
-            for value in value_list:
-                arg = (key, value)
-                form_args.append(arg)
-
         # Confirm the password reset
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.submit_token_resource),
             "POST",
             path,
-            content=urlencode(form_args).encode("utf8"),
+            content=b"",
             shorthand=False,
             content_is_form=True,
         )
@@ -310,7 +357,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
     def _reset_password(
         self, new_password, session_id, client_secret, expected_code=200
     ):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/password",
             {
@@ -352,8 +399,8 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
 
         # Check that this access token has been invalidated.
-        request, channel = self.make_request("GET", "account/whoami")
-        self.assertEqual(request.code, 401)
+        channel = self.make_request("GET", "account/whoami")
+        self.assertEqual(channel.code, 401)
 
     def test_pending_invites(self):
         """Tests that deactivating a user rejects every pending invite for them."""
@@ -407,10 +454,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
                 "erase": False,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
 
 
 class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@@ -517,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_address_trim(self):
         self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_ip(self):
+        """Tests that adding emails is ratelimited by IP
+        """
+
+        # We expect to be able to set three emails before getting ratelimited.
+        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+        with self.assertRaises(HttpResponseException) as cm:
+            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
         """
@@ -530,7 +592,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         self._validate_token(link)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -548,7 +610,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -569,7 +631,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/delete",
             {"medium": "email", "address": self.email},
@@ -578,7 +640,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -601,7 +663,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/delete",
             {"medium": "email", "address": self.email},
@@ -612,7 +674,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -629,7 +691,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEquals(len(self.email_attempts), 1)
 
         # Attempt to add email without clicking the link
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -647,7 +709,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -662,7 +724,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         session_id = "weasle"
 
         # Attempt to add email without even requesting an email
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -680,7 +742,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -784,17 +846,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         if next_link:
             body["next_link"] = next_link
 
-        request, channel = self.make_request(
-            "POST", b"account/3pid/email/requestToken", body,
-        )
-        self.assertEquals(expect_code, channel.code, channel.result)
+        channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
+
+        if channel.code != expect_code:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body.get("sid")
 
     def _request_token_invalid_email(
         self, email, expected_errcode, expected_error, client_secret="foobar",
     ):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
@@ -807,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         # Remove the host
         path = link.replace("https://example.com", "")
 
-        request, channel = self.make_request("GET", path, shorthand=False)
+        channel = self.make_request("GET", path, shorthand=False)
         self.assertEquals(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
@@ -833,15 +897,17 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def _add_email(self, request_email, expected_email):
         """Test adding an email to profile
         """
+        previous_email_attempts = len(self.email_attempts)
+
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -859,10 +925,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+        self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 New Vector
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,20 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Union
+from typing import Union
 
 from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.site import SynapseRequest
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -64,11 +68,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
 
     def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
         """Make a register request."""
-        request, channel = self.make_request(
-            "POST", "register", body
-        )  # type: SynapseRequest, FakeChannel
+        channel = self.make_request("POST", "register", body)
 
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
         return channel
 
     def recaptcha(
@@ -78,18 +80,18 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         if post_session is None:
             post_session = session
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request.code, 200)
+        )
+        self.assertEqual(channel.code, 200)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "auth/m.login.recaptcha/fallback/web?session="
             + post_session
             + "&g-recaptcha-response=a",
         )
-        self.assertEqual(request.code, expected_post_response)
+        self.assertEqual(channel.code, expected_post_response)
 
         # The recaptcha handler is called with the response given
         attempts = self.recaptcha_checker.recaptcha_attempts
@@ -156,31 +158,44 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = "https://synapse.test"
+
+        if HAS_OIDC:
+            # we enable OIDC as a way of testing SSO flows
+            oidc_config = {}
+            oidc_config.update(TEST_OIDC_CONFIG)
+            oidc_config["allow_existing_users"] = True
+            config["oidc_config"] = oidc_config
+
+        return config
+
+    def create_resource_dict(self):
+        resource_dict = super().create_resource_dict()
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
+        return resource_dict
+
     def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
-        self.user_tok = self.login("test", self.user_pass)
-
-    def get_device_ids(self) -> List[str]:
-        # Get the list of devices so one can be deleted.
-        request, channel = self.make_request(
-            "GET", "devices", access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
-
-        # Get the ID of the device.
-        self.assertEqual(request.code, 200)
-        return [d["device_id"] for d in channel.json_body["devices"]]
+        self.device_id = "dev1"
+        self.user_tok = self.login("test", self.user_pass, self.device_id)
 
     def delete_device(
-        self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+        self,
+        access_token: str,
+        device: str,
+        expected_response: int,
+        body: Union[bytes, JsonDict] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
-        request, channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=self.user_tok
-        )  # type: SynapseRequest, FakeChannel
+        channel = self.make_request(
+            "DELETE", "devices/" + device, body, access_token=access_token,
+        )
 
         # Ensure the response is sane.
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
 
         return channel
 
@@ -188,12 +203,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Delete 1 or more devices."""
         # Note that this uses the delete_devices endpoint so that we can modify
         # the payload half-way through some tests.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "delete_devices", body, access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
+        )
 
         # Ensure the response is sane.
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
 
         return channel
 
@@ -201,11 +216,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         Test user interactive authentication outside of registration.
         """
-        device_id = self.get_device_ids()[0]
-
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -214,7 +227,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow.
         self.delete_device(
-            device_id,
+            self.user_tok,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -233,13 +247,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        device_id = self.get_device_ids()[0]
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
-            device_id,
+            self.user_tok,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -262,14 +276,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session ID should be rejected.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids()
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+        channel = self.delete_devices(401, {"devices": [self.device_id]})
 
         # Grab the session
         session = channel.json_body["session"]
@@ -281,7 +292,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_devices(
             200,
             {
-                "devices": [device_ids[1]],
+                "devices": ["dev2"],
                 "auth": {
                     "type": "m.login.password",
                     "identifier": {"type": "m.id.user", "user": self.user},
@@ -296,14 +307,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         The initial requested URI cannot be modified during the user interactive authentication session.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids()
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -312,8 +320,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow, but try to delete the
         # second device. This results in an error.
+        #
+        # This makes use of the fact that the device ID is embedded into the URL.
         self.delete_device(
-            device_ids[1],
+            self.user_tok,
+            "dev2",
             403,
             {
                 "auth": {
@@ -324,3 +335,152 @@ class UIAuthTests(unittest.HomeserverTestCase):
                 },
             },
         )
+
+    @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+    def test_can_reuse_session(self):
+        """
+        The session can be reused if configured.
+
+        Compare to test_cannot_change_uri.
+        """
+        # Create a second and third login.
+        self.login("test", self.user_pass, "dev2")
+        self.login("test", self.user_pass, "dev3")
+
+        # Attempt to delete a device. This works since the user just logged in.
+        self.delete_device(self.user_tok, "dev2", 200)
+
+        # Move the clock forward past the validation timeout.
+        self.reactor.advance(6)
+
+        # Deleting another devices throws the user into UI auth.
+        channel = self.delete_device(self.user_tok, "dev3", 401)
+
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # Make another request providing the UI auth flow.
+        self.delete_device(
+            self.user_tok,
+            "dev3",
+            200,
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {"type": "m.id.user", "user": self.user},
+                    "password": self.user_pass,
+                    "session": session,
+                },
+            },
+        )
+
+        # Make another request, but try to delete the first device. This works
+        # due to re-using the previous session.
+        #
+        # Note that *no auth* information is provided, not even a session iD!
+        self.delete_device(self.user_tok, self.device_id, 200)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_via_sso(self):
+        """Test a successful UI Auth flow via SSO
+
+        This includes:
+          * hitting the UIA SSO redirect endpoint
+          * checking it serves a confirmation page which links to the OIDC provider
+          * calling back to the synapse oidc callback
+          * checking that the original operation succeeds
+        """
+
+        # log the user in
+        remote_user_id = UserID.from_string(self.user).localpart
+        login_resp = self.helper.login_via_oidc(remote_user_id)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # initiate a UI Auth process by attempting to delete the device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        # check that SSO is offered
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+
+        # run the UIA-via-SSO flow
+        session_id = channel.json_body["session"]
+        channel = self.helper.auth_via_oidc(
+            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        )
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # and now the delete request should succeed.
+        self.delete_device(
+            self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+        )
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_does_not_offer_password_for_sso_user(self):
+        login_resp = self.helper.login_via_oidc("username")
+        user_tok = login_resp["access_token"]
+        device_id = login_resp["device_id"]
+
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        channel = self.delete_device(user_tok, device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+    def test_does_not_offer_sso_for_password_user(self):
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_offers_both_flows_for_upgraded_user(self):
+        """A user that had a password and then logged in with SSO should get both flows
+        """
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        # we have no particular expectations of ordering here
+        self.assertIn({"stages": ["m.login.password"]}, flows)
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        self.assertEqual(len(flows), 2)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_fails_for_incorrect_sso_user(self):
+        """If the user tries to authenticate with the wrong SSO user, they get an error
+        """
+        # log the user in
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # start a UI Auth flow by attempting to delete a device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        session_id = channel.json_body["session"]
+
+        # do the OIDC auth, but auth as the wrong user
+        channel = self.helper.auth_via_oidc(
+            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        )
+
+        # that should return a failure message
+        self.assertSubstring("We were unable to validate", channel.text_body)
+
+        # ... and the delete op should now fail with a 403
+        self.delete_device(
+            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+        )
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index 767e126875..e808339fb3 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -36,7 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         return hs
 
     def test_check_auth_required(self):
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
@@ -44,7 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.register_user("user", "pass")
         access_token = self.login("user", "pass")
 
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
@@ -62,7 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         user = self.register_user(localpart, password)
         access_token = self.login(user, password)
 
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
@@ -70,7 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         # Test case where password is handled outside of Synapse
         self.assertTrue(capabilities["m.change_password"]["enabled"])
         self.get_success(self.store.user_set_password_hash(user, None))
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 231d5aefea..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastore()
 
     def test_add_filter(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
@@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(filter.result, self.EXAMPLE_FILTER)
 
     def test_add_filter_for_other_user(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
             self.EXAMPLE_FILTER_JSON,
@@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     def test_add_filter_non_local_user(self):
         _is_mine = self.hs.is_mine
         self.hs.is_mine = lambda target_user: False
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
@@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
         self.reactor.advance(1)
         filter_id = filter_id.result
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
 
@@ -87,7 +87,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
 
@@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
     def test_get_filter_invalid_id(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
 
@@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     # No ID also returns an invalid_id error
     def test_get_filter_no_id(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )
 
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index ee86b94917..fba34def30 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
     def test_get_policy(self):
         """Tests if the /password_policy endpoint returns the configured policy."""
 
-        request, channel = self.make_request(
-            "GET", "/_matrix/client/r0/password_policy"
-        )
+        channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
 
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(
@@ -89,7 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_too_short(self):
         request_data = json.dumps({"username": "kermit", "password": "shorty"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -98,7 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_digit(self):
         request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -107,7 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_symbol(self):
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -116,7 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_uppercase(self):
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -125,7 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_lowercase(self):
         request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -134,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_compliant(self):
         request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         # Getting a 401 here means the password has passed validation and the server has
         # responded with a list of registration flows.
@@ -160,7 +158,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
                 },
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/account/password",
             request_data,
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8f0c2430e8..27db4f551e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -61,7 +61,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastore().services_cache.append(appservice)
         request_data = json.dumps({"username": "as_user_kermit"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
@@ -72,7 +72,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_appservice_registration_invalid(self):
         self.appservice = None  # no application service exists
         request_data = json.dumps({"username": "kermit"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
@@ -80,14 +80,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_POST_bad_password(self):
         request_data = json.dumps({"username": "kermit", "password": 666})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
 
     def test_POST_bad_username(self):
         request_data = json.dumps({"username": 777, "password": "monkey"})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "auth": {"type": LoginType.DUMMY},
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         det_data = {
             "user_id": user_id,
@@ -117,16 +117,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
 
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+        self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
 
     def test_POST_guest_registration(self):
         self.hs.config.macaroon_secret_key = "test"
         self.hs.config.allow_guest_access = True
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
         self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_disabled_guest_registration(self):
         self.hs.config.allow_guest_access = False
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -144,7 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_ratelimiting_guest(self):
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
-            request, channel = self.make_request(b"POST", url, b"{}")
+            channel = self.make_request(b"POST", url, b"{}")
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -154,7 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -168,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 "auth": {"type": LoginType.DUMMY},
             }
             request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", self.url, request_data)
+            channel = self.make_request(b"POST", self.url, request_data)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -178,12 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_advertised_flows(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -206,7 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_advertised_flows_captcha_and_terms_and_3pids(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -238,7 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_advertised_flows_no_msisdn_email_required(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -278,7 +279,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": email, "send_attempt": 1},
@@ -317,13 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
 
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(
@@ -345,14 +346,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/account_validity/validity"
         params = {"user_id": user_id}
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_manual_expire(self):
@@ -369,14 +368,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -396,20 +393,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Try to log the user out
-        request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+        channel = self.make_request(b"POST", "/logout", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Log the user in again (allowed for expired accounts)
         tok = self.login("kermit", "monkey")
 
         # Try to log out all of the user's sessions
-        request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+        channel = self.make_request(b"POST", "/logout/all", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
@@ -483,7 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # retrieve the token from the DB.
         renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
-        request, channel = self.make_request(b"GET", url)
+        channel = self.make_request(b"GET", url)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Check that we're getting HTML back.
@@ -503,14 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # our access token should be denied from now, otherwise they should
         # succeed.
         self.reactor.advance(datetime.timedelta(days=3).total_seconds())
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_renewal_invalid_token(self):
         # Hit the renewal endpoint with an invalid token and check that it behaves as
         # expected, i.e. that it responds with 404 Not Found and the correct HTML.
         url = "/_matrix/client/unstable/account_validity/renew?token=123"
-        request, channel = self.make_request(b"GET", url)
+        channel = self.make_request(b"GET", url)
         self.assertEquals(channel.result["code"], b"404", channel.result)
 
         # Check that we're getting HTML back.
@@ -531,7 +526,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST",
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
@@ -555,10 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
                 "erase": False,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
 
         self.reactor.advance(datetime.timedelta(days=8).total_seconds())
 
@@ -606,7 +601,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         self.email_attempts = []
 
         # Test that we're still able to manually trigger a mail to be sent.
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST",
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 6cd4eb6624..bd574077e7 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -60,7 +60,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, event_id),
             access_token=self.user_token,
@@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         annotation_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
             % (self.room, self.parent_id),
@@ -152,7 +152,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
                 % (self.room, self.parent_id, from_token),
@@ -210,7 +210,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
                 % (self.room, self.parent_id, from_token),
@@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s"
                 "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
@@ -325,7 +325,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s"
             % (self.room, self.parent_id),
@@ -357,7 +357,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Now lets redact one of the 'a' reactions
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
             access_token=self.user_token,
@@ -365,7 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s"
             % (self.room, self.parent_id),
@@ -382,7 +382,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         """Test that aggregations must be annotations.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
             % (self.room, self.parent_id, RelationTypes.REPLACE),
@@ -414,7 +414,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         reply_2 = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -450,7 +450,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         edit_event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -507,7 +507,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -549,7 +549,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
             % (self.room, original_event_id),
@@ -561,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(len(channel.json_body["chunk"]), 1)
 
         # Redact the original event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/redact/%s/%s"
             % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
@@ -571,7 +571,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Try to check for remaining m.replace relations
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
             % (self.room, original_event_id),
@@ -598,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Redact the original
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/redact/%s/%s"
             % (
@@ -612,7 +612,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Check that aggregations returns zero
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
             % (self.room, original_event_id),
@@ -656,7 +656,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         original_id = parent_id if parent_id else self.parent_id
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, original_id, relation_type, event_type, query),
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 562a9c1ba4..116ace1812 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import shared_rooms
 
 from tests import unittest
+from tests.server import FakeChannel
 
 
 class UserSharedRoomsTest(unittest.HomeserverTestCase):
@@ -40,14 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.store = hs.get_datastore()
         self.handler = hs.get_user_directory_handler()
 
-    def _get_shared_rooms(self, token, other_user):
-        request, channel = self.make_request(
+    def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+        return self.make_request(
             "GET",
             "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
             % other_user,
             access_token=token,
         )
-        return request, channel
 
     def test_shared_room_list_public(self):
         """
@@ -63,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
         self.helper.join(room, user=u2, tok=u2_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
@@ -82,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
         self.helper.join(room, user=u2, tok=u2_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
@@ -104,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room_public, user=u2, tok=u2_token)
         self.helper.join(room_private, user=u1, tok=u1_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 2)
         self.assertTrue(room_public in channel.json_body["joined"])
@@ -125,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Assert user directory is not empty
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
 
         self.helper.leave(room, user=u1, tok=u1_token)
 
-        request, channel = self._get_shared_rooms(u2_token, u1)
+        channel = self._get_shared_rooms(u2_token, u1)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 31ac0fccb8..512e36c236 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -35,7 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     ]
 
     def test_sync_argless(self):
-        request, channel = self.make_request("GET", "/sync")
+        channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
         self.assertTrue(
@@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         """
         self.hs.config.use_presence = False
 
-        request, channel = self.make_request("GET", "/sync")
+        channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
         self.assertTrue(
@@ -194,7 +194,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             tok=tok,
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/sync?filter=%s" % sync_filter, access_token=tok
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -245,21 +245,19 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.helper.send(room, body="There!", tok=other_access_token)
 
         # Start typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
         )
         self.assertEquals(200, channel.code)
 
-        request, channel = self.make_request(
-            "GET", "/sync?access_token=%s" % (access_token,)
-        )
+        channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # Stop typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": false}',
@@ -267,7 +265,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         # Start typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
@@ -275,9 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         # Should return immediately
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -289,9 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         # invalidate the stream token.
         self.helper.send(room, body="There!", tok=other_access_token)
 
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -299,9 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         # ahead, and therefore it's saying the typing (that we've actually
         # already seen) is new, since it's got a token above our new, now-reset
         # stream token.
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -383,7 +375,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         # Send a read receipt to tell the server we've read the latest event.
         body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/read_markers" % self.room_id,
             body,
@@ -450,7 +442,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
     def _check_unread_count(self, expected_count: True):
         """Syncs and compares the unread count with the expected value."""
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url % self.next_batch, access_token=self.tok,
         )
 
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index fbcf8d5b86..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -39,7 +39,7 @@ from tests.utils import default_config
 class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        return self.setup_test_homeserver(http_client=self.http_client)
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
     def create_test_resource(self):
         return create_resource_tree(
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             }
         ]
         self.hs2 = self.setup_test_homeserver(
-            http_client=self.http_client2, config=config
+            federation_http_client=self.http_client2, config=config
         )
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2a3b2a8f27..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         config = self.default_config()
         config["media_store_path"] = self.media_store_path
-        config["thumbnail_requirements"] = {}
         config["max_image_pixels"] = 2000000
 
         provider_config = {
@@ -214,7 +213,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
@@ -228,7 +227,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def _req(self, content_disposition):
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -313,18 +312,42 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
     def test_thumbnail_crop(self):
+        """Test that a cropped remote thumbnail is available."""
         self._test_thumbnail(
             "crop", self.test_image.expected_cropped, self.test_image.expected_found
         )
 
     def test_thumbnail_scale(self):
+        """Test that a scaled remote thumbnail is available."""
         self._test_thumbnail(
             "scale", self.test_image.expected_scaled, self.test_image.expected_found
         )
 
+    def test_invalid_type(self):
+        """An invalid thumbnail type is never available."""
+        self._test_thumbnail("invalid", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+    )
+    def test_no_thumbnail_crop(self):
+        """
+        Override the config to generate only scaled thumbnails, but request a cropped one.
+        """
+        self._test_thumbnail("crop", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+    )
+    def test_no_thumbnail_scale(self):
+        """
+        Override the config to generate only cropped thumbnails, but request a scaled one.
+        """
+        self._test_thumbnail("scale", None, False)
+
     def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.thumbnail_resource),
             "GET",
@@ -362,3 +385,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                     "error": "Not found [b'example.com', b'12345']",
                 },
             )
+
+    def test_x_robots_tag_header(self):
+        """
+        Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+        to not index, archive, or follow links in media.
+        """
+        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"X-Robots-Tag"),
+            [b"noindex, nofollow, noarchive, noimageindex"],
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,42 +18,23 @@ import re
 
 from mock import patch
 
-import attr
-
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
 
 from tests import unittest
 from tests.server import FakeTransport
 
-
-@attr.s
-class FakeResponse:
-    version = attr.ib()
-    code = attr.ib()
-    phrase = attr.ib()
-    headers = attr.ib()
-    body = attr.ib()
-    absoluteURI = attr.ib()
-
-    @property
-    def request(self):
-        @attr.s
-        class FakeTransport:
-            absoluteURI = self.absoluteURI
-
-        return FakeTransport()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
+try:
+    import lxml
+except ImportError:
+    lxml = None
 
 
 class URLPreviewTests(unittest.HomeserverTestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
 
     hijack_auth = True
     user_id = "@test:user"
@@ -139,7 +120,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     def test_cache_returns_correct_type(self):
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -164,7 +145,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
 
         # Check the cache returns the correct response
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://matrix.org", shorthand=False
         )
 
@@ -180,7 +161,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertNotIn("http://matrix.org", self.preview_url._cache)
 
         # Check the database cache returns the correct response
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://matrix.org", shorthand=False
         )
 
@@ -201,7 +182,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -236,7 +217,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -271,7 +252,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -304,7 +285,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -334,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -355,7 +336,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -372,7 +353,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Blacklisted IP addresses, accessed directly, are not spidered.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://192.168.1.1", shorthand=False
         )
 
@@ -391,7 +372,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Blacklisted IP ranges, accessed directly, are not spidered.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://1.1.1.2", shorthand=False
         )
 
@@ -411,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -448,7 +429,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             (IPv4Address, "10.1.2.3"),
         ]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
         self.assertEqual(channel.code, 502)
@@ -468,7 +449,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
         ]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -489,7 +470,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -506,7 +487,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         OPTIONS returns the OPTIONS.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "OPTIONS", "preview_url?url=http://example.com", shorthand=False
         )
         self.assertEqual(channel.code, 200)
@@ -519,7 +500,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         # Build and make a request to the server
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -593,7 +574,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 b"</head></html>"
             )
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
                 shorthand=False,
@@ -658,7 +639,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             }
             end_content = json.dumps(result).encode("utf-8")
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
                 shorthand=False,
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 02a46e5fda..32acd93dc1 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase):
         return HealthResource()
 
     def test_health(self):
-        request, channel = self.make_request("GET", "/health", shorthand=False)
+        channel = self.make_request("GET", "/health", shorthand=False)
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 6a930f4148..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -28,11 +28,11 @@ class WellKnownTests(unittest.HomeserverTestCase):
         self.hs.config.public_baseurl = "https://tesths"
         self.hs.config.default_identity_server = "https://testis"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(
             channel.json_body,
             {
@@ -44,8 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase):
     def test_well_known_no_public_baseurl(self):
         self.hs.config.public_baseurl = None
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(request.code, 404)
+        self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -47,13 +47,26 @@ class FakeChannel:
     site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
+    _ip = attr.ib(type=str, default="127.0.0.1")
     _producer = None
 
     @property
     def json_body(self):
-        if not self.result:
-            raise Exception("No result yet.")
-        return json.loads(self.result["body"].decode("utf8"))
+        return json.loads(self.text_body)
+
+    @property
+    def text_body(self) -> str:
+        """The body of the result, utf-8-decoded.
+
+        Raises an exception if the request has not yet completed.
+        """
+        if not self.is_finished:
+            raise Exception("Request not yet completed")
+        return self.result["body"].decode("utf8")
+
+    def is_finished(self) -> bool:
+        """check if the response has been completely received"""
+        return self.result.get("done", False)
 
     @property
     def code(self):
@@ -62,7 +75,7 @@ class FakeChannel:
         return int(self.result["code"])
 
     @property
-    def headers(self):
+    def headers(self) -> Headers:
         if not self.result:
             raise Exception("No result yet.")
         h = Headers()
@@ -108,7 +121,7 @@ class FakeChannel:
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address("TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
         return None
@@ -124,7 +137,7 @@ class FakeChannel:
         self._reactor.run()
         x = 0
 
-        while not self.result.get("done"):
+        while not self.is_finished():
             # If there's a producer, tell it to resume producing so we get content
             if self._producer:
                 self._producer.resumeProducing()
@@ -136,6 +149,16 @@ class FakeChannel:
 
             self._reactor.advance(0.1)
 
+    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+        """Process the contents of any Set-Cookie headers in the response
+
+        Any cookines found are added to the given dict
+        """
+        for h in self.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
 
 class FakeSite:
     """
@@ -174,11 +197,12 @@ def make_request(
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
     ] = None,
-):
+    client_ip: str = "127.0.0.1",
+) -> FakeChannel:
     """
     Make a web request using the given method, path and content, and render it
 
-    Returns the Request and the Channel underneath.
+    Returns the fake Channel object which records the response to the request.
 
     Args:
         site: The twisted Site to use to render the request
@@ -201,8 +225,11 @@ def make_request(
              will pump the reactor until the the renderer tells the channel the request
              is finished.
 
+        client_ip: The IP to use as the requesting IP. Useful for testing
+            ratelimiting.
+
     Returns:
-        Tuple[synapse.http.site.SynapseRequest, channel]
+        channel
     """
     if not isinstance(method, bytes):
         method = method.encode("ascii")
@@ -216,8 +243,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")
 
     if not path.startswith(b"/"):
         path = b"/" + path
@@ -227,7 +255,7 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    channel = FakeChannel(site, reactor)
+    channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel)
     req.content = BytesIO(content)
@@ -258,12 +286,13 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)
 
+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")
 
     if await_result:
         channel.await_result()
 
-    return req, channel
+    return channel
 
 
 @implementer(IReactorPluggableNameResolver)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index e0a9cd93ac..4dd5a36178 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         the notice URL + an authentication code.
         """
         # Initial sync, to get the user consent room invite
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/sync", access_token=self.access_token
         )
         self.assertEqual(channel.code, 200)
@@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
 
         # Join the room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/" + room_id + "/join",
             access_token=self.access_token,
@@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # Sync again, to get the message in the room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/sync", access_token=self.access_token
         )
         self.assertEqual(channel.code, 200)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9c8027a5b2..fea54464af 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.register_user("user", "password")
         tok = self.login("user", "password")
 
-        request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+        channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
 
         invites = channel.json_body["rooms"]["invite"]
         self.assertEqual(len(invites), 0, invites)
@@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         # Sync again to retrieve the events in the room, so we can check whether this
         # room has a notice in it.
-        request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+        channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
 
         # Scan the events in the room to search for a message from the server notices
         # user.
@@ -353,9 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
             tok = self.login(localpart, "password")
 
             # Sync with the user's token to mark the user as active.
-            request, channel = self.make_request(
-                "GET", "/sync?timeout=0", access_token=tok,
-            )
+            channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
 
             # Also retrieves the list of invites for this user. We don't care about that
             # one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
 from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+    _get_auth_chain_difference,
+    lexicographical_topological_sort,
+    resolve_events_with_store,
+)
 from synapse.types import EventID
 
 from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
         event_dict = {
             "auth_events": [(a, {}) for a in auth_events],
             "prev_events": [(p, {}) for p in prev_events],
-            "event_id": self.node_id,
+            "event_id": self.event_id,
             "sender": self.sender,
             "type": self.type,
             "content": self.content,
@@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
+    def test_mainline_sort(self):
+        """Tests that the mainline ordering works correctly.
+        """
+
+        events = [
+            FakeEvent(
+                id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="PA1",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={"users": {ALICE: 100, BOB: 50}},
+            ),
+            FakeEvent(
+                id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="PA2",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "users": {ALICE: 100, BOB: 50},
+                    "events": {EventTypes.PowerLevels: 100},
+                },
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={"users": {ALICE: 100, BOB: 50}},
+            ),
+            FakeEvent(
+                id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+        ]
+
+        edges = [
+            ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+            ["END", "T4", "PB", "PA1"],
+        ]
+
+        # We expect T3 to be picked as the other topics are pointing at older
+        # power levels. Note that without mainline ordering we'd pick T4 due to
+        # it being sent *after* T3.
+        expected_state_ids = ["T3", "PA2"]
+
+        self.do_check(events, edges, expected_state_ids)
+
     def do_check(self, events, edges, expected_state_ids):
         """Take a list of events and edges and calculate the state of the
         graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
         self.assert_dict(self.expected_combined_state, state)
 
 
+class AuthChainDifferenceTestCase(unittest.TestCase):
+    """We test that `_get_auth_chain_difference` correctly handles unpersisted
+    events.
+    """
+
+    def test_simple(self):
+        # Test getting the auth difference for a simple chain with a single
+        # unpersisted event:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #           C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c}
+
+        state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {c.event_id})
+
+    def test_multiple_unpersisted_chain(self):
+        # Test getting the auth difference for a simple chain with multiple
+        # unpersisted events:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #      D -> C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, c.event_id})
+
+    def test_unpersisted_events_different_sets(self):
+        # Test getting the auth difference for with multiple unpersisted events
+        # in different branches:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #     D --> C -|-> B -> A
+        #     E ----^ -|---^
+        #              |
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        e = FakeEvent(
+            id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id, b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, e.event_id})
+
+
 def pairwise(iterable):
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
@@ -647,7 +834,7 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, auth_sets):
+    def get_auth_chain_difference(self, room_id, auth_sets):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+    def prepare(self, hs, reactor, clock):
+        self.store = self.hs.get_datastore()
+        self.user = "@user:test"
+
+    def _update_ignore_list(
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+    ) -> None:
+        """Update the account data to block the given users."""
+        if ignorer_user_id is None:
+            ignorer_user_id = self.user
+
+        self.get_success(
+            self.store.add_account_data_for_user(
+                ignorer_user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {u: {} for u in ignored_user_ids}},
+            )
+        )
+
+    def assert_ignorers(
+        self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_by(ignored_user_id)),
+            expected_ignorer_user_ids,
+        )
+
+    def test_ignoring_users(self):
+        """Basic adding/removing of users from the ignore list."""
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        # Check a user which no one ignores.
+        self.assert_ignorers("@user:test", set())
+
+        # Check a local user which is ignored.
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Check a remote user which is ignored.
+        self.assert_ignorers("@another:remote", {self.user})
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        # Check the removed user.
+        self.assert_ignorers("@other:test", set())
+
+        # Check the added user.
+        self.assert_ignorers("@foo:test", {self.user})
+
+        # Check the removed user.
+        self.assert_ignorers("@another:remote", {self.user})
+
+    def test_caching(self):
+        """Ensure that caching works properly between different users."""
+        # The first user ignores a user.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # The second user ignores them.
+        self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+        # The first user un-ignores them.
+        self._update_ignore_list()
+        self.assert_ignorers("@other:test", {"@second:test"})
+
+    def test_invalid_data(self):
+        """Invalid data ends up clearing out the ignored users list."""
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # No ignored_users key.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
+
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Invalid data.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": "unexpected"},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_count_devices_by_users(self):
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device1", "display_name 1")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device2", "display_name 2")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id2", "device3", "display_name 3")
+        )
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        self.assertEqual(0, res)
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        self.assertEqual(0, res)
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        self.assertEqual(2, res)
+
+        res = yield defer.ensureDeferred(
+            self.store.count_devices_by_users(["user_id", "user_id2"])
+        )
+        self.assertEqual(3, res)
+
+    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
 
 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.store = hs.get_datastore()
         return hs
 
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..0c46ad595b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,741 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Set, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self._next_stream_ordering = 1
+
+    def test_simple(self):
+        """Test that the example in `docs/auth_chain_difference_algorithm.md`
+        works.
+        """
+
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        power_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        bob_join_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
+        events = [
+            create,
+            bob_join,
+            power,
+            alice_invite,
+            alice_join,
+            bob_join_2,
+            power_2,
+            alice_join2,
+        ]
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+            (bob_join_2, power),
+            (alice_join2, power_2),
+        ]
+
+        self.persist(events)
+        chain_map, link_map = self.fetch_chains(events)
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+        # Test that everything can reach the create event, but the create event
+        # can't reach anything.
+        for event in events[1:]:
+            self.assertTrue(
+                link_map.exists_path_from(
+                    chain_map[event.event_id], chain_map[create.event_id]
+                ),
+            )
+
+            self.assertFalse(
+                link_map.exists_path_from(
+                    chain_map[create.event_id], chain_map[event.event_id],
+                ),
+            )
+
+    def test_out_of_order_events(self):
+        """Test that we handle persisting events that we don't have the full
+        auth chain for yet (which should only happen for out of band memberships).
+        """
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        # First persist the base room.
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        self.persist([create, bob_join, power])
+
+        # Now persist an invite and a couple of memberships out of order.
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+            )
+        )
+
+        self.persist([alice_join])
+        self.persist([alice_join2])
+        self.persist([alice_invite])
+
+        # The end result should be sane.
+        events = [create, bob_join, power, alice_invite, alice_join]
+
+        chain_map, link_map = self.fetch_chains(events)
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+        ]
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+    def persist(
+        self, events: List[EventBase],
+    ):
+        """Persist the given events and check that the links generated match
+        those given.
+        """
+
+        persist_events_store = self.hs.get_datastores().persist_events
+
+        for e in events:
+            e.internal_metadata.stream_ordering = self._next_stream_ordering
+            self._next_stream_ordering += 1
+
+        def _persist(txn):
+            # We need to persist the events to the events and state_events
+            # tables.
+            persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+            # Actually call the function that calculates the auth chain stuff.
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+        self.get_success(
+            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+        )
+
+    def fetch_chains(
+        self, events: List[EventBase]
+    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+        # Fetch the map from event ID -> (chain ID, sequence number)
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chains",
+                column="event_id",
+                iterable=[e.event_id for e in events],
+                retcols=("event_id", "chain_id", "sequence_number"),
+                keyvalues={},
+            )
+        )
+
+        chain_map = {
+            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+        }
+
+        # Fetch all the links and pass them to the _LinkMap.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable=[chain_id for chain_id, _ in chain_map.values()],
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
+                keyvalues={},
+            )
+        )
+
+        link_map = _LinkMap()
+        for row in rows:
+            added = link_map.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+            )
+
+            # We shouldn't have persisted any redundant links
+            self.assertTrue(added)
+
+        return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+    def test_simple(self):
+        """Basic tests for the LinkMap.
+        """
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
+        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+        # Attempting to add a redundant link is ignored.
+        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+        # Adding new non-redundant links works
+        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+        self.token = self.login("foo", "pass")
+        self.requester = create_requester(self.user_id)
+
+    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+        """Insert a room without a chain cover index.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+        # Mark the room as not having a chain cover index
+        self.get_success(
+            self.store.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+                desc="test",
+            )
+        )
+
+        # Create a fork in the DAG with different events.
+        event_handler = self.hs.get_event_creation_handler()
+        latest_event_ids = self.get_success(
+            self.store.get_prev_events_for_room(room_id)
+        )
+        event, context = self.get_success(
+            event_handler.create_event(
+                self.requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": self.user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(self.requester, event, context)
+        )
+        state1 = set(self.get_success(context.get_current_state_ids()).values())
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                self.requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": self.user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(self.requester, event, context)
+        )
+        state2 = set(self.get_success(context.get_current_state_ids()).values())
+
+        # Delete the chain cover info.
+
+        def _delete_tables(txn):
+            txn.execute("DELETE FROM event_auth_chains")
+            txn.execute("DELETE FROM event_auth_chain_links")
+
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states2 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
+            )
+        )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create the rooms
+        room_id1, _ = self._generate_room()
+        room_id2, _ = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id1, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id2, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,77 +165,279 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
 
-            depth = depth_map[event_id]
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
+
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "d", "e"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
-        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..71210ce606
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremPruneTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.state = self.hs.get_state_handler()
+        self.persistence = self.hs.get_storage().persistence
+        self.store = self.hs.get_datastore()
+
+        self.register_user("user", "pass")
+        self.token = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=self.token
+        )
+
+        body = self.helper.send(self.room_id, body="Test", tok=self.token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a remote event and persist it. This will be the extremity before
+        # the gap.
+        self.remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "state_key": "@user:other",
+                "content": {},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        self.persist_event(self.remote_event_1)
+
+        # Check that the current extremities is the remote event.
+        self.assert_extremities([self.remote_event_1.event_id])
+
+    def persist_event(self, event, state=None):
+        """Persist the event, with optional state
+        """
+        context = self.get_success(
+            self.state.compute_event_context(event, old_state=state)
+        )
+        self.get_success(self.persistence.persist_event(event, context))
+
+    def assert_extremities(self, expected_extremities):
+        """Assert the current extremities for the room
+        """
+        extremities = self.get_success(
+            self.store.get_prev_events_for_room(self.room_id)
+        )
+        self.assertCountEqual(extremities, expected_extremities)
+
+    def test_prune_gap(self):
+        """Test that we drop extremities after a gap when we see an event from
+        the same domain.
+        """
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 50,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_do_not_prune_gap_if_state_different(self):
+        """Test that we don't prune extremities after a gap if the resolved
+        state is different.
+        """
+
+        # Fudge a second event which points to an event we don't have.
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "state_key": "@user:other",
+                "content": {},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 10,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        # Now we persist it with state with a dropped history visibility
+        # setting. The state resolution across the old and new event will then
+        # include it, and so the resolved state won't match the new state.
+        state_before_gap = dict(
+            self.get_success(self.state.get_current_state(self.room_id))
+        )
+        state_before_gap.pop(("m.room.history_visibility", ""))
+
+        context = self.get_success(
+            self.state.compute_event_context(
+                remote_event_2, old_state=state_before_gap.values()
+            )
+        )
+
+        self.get_success(self.persistence.persist_event(remote_event_2, context))
+
+        # Check that we haven't dropped the old extremity.
+        self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+    def test_prune_gap_if_old(self):
+        """Test that we drop extremities after a gap when the previous extremity
+        is "old"
+        """
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_do_not_prune_gap_if_other_server(self):
+        """Test that we do not drop extremities after a gap when we see an event
+        from a different domain.
+        """
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+    def test_prune_gap_if_dummy_remote(self):
+        """Test that we drop extremities after a gap when the previous extremity
+        is a local dummy event and only points to remote events.
+        """
+
+        body = self.helper.send_event(
+            self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+        )
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_prune_gap_if_dummy_local(self):
+        """Test that we don't drop extremities after a gap when the previous
+        extremity is a local dummy event and points to local events.
+        """
+
+        body = self.helper.send(self.room_id, body="Test", tok=self.token)
+
+        body = self.helper.send_event(
+            self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+        )
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id, local_message_event_id])
+
+    def test_do_not_prune_gap_if_not_dummy(self):
+        """Test that we do not drop extremities after a gap when the previous extremity
+        is not a dummy event.
+        """
+
+        body = self.helper.send(self.room_id, body="test", tok=self.token)
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..3e2fd4da01 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
             )
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
                 positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar1 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+        txn.execute(
+            """
+            CREATE TABLE foobar2 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                stream_name="test_stream",
+                instance_name=instance_name,
+                tables=[
+                    ("foobar1", "instance_name", "stream_id"),
+                    ("foobar2", "instance_name", "stream_id"),
+                ],
+                sequence_name="foobar_seq",
+                writers=writers,
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def _insert_rows(
+        self,
+        table: str,
+        instance_name: str,
+        number: int,
+        update_stream_table: bool = True,
+    ):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+                    (instance_name,),
+                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                        """,
+                        (instance_name,),
+                    )
+
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def test_load_existing_stream(self):
+        """Test creating ID gens with multiple tables that have rows from after
+        the position in `stream_positions` table.
+        """
+        self._insert_rows("foobar1", "first", 3)
+        self._insert_rows("foobar2", "second", 3)
+        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..e9e3bca3bf 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase):
 
         self.assertEquals(1, total)
         self.assertEquals(self.displayname, users.pop()["displayname"])
+
+        users, total = yield defer.ensureDeferred(
+            self.store.get_users_paginate(0, 10, name="BC", guests=False)
+        )
+
+        self.assertEquals(1, total)
+        self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..ea63bd56b4 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
             ),
         )
 
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_avatar_url(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
                 )
             ),
         )
+
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            )
+        )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
     servlets = [room.register_servlets]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d4f9e809db..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from mock import Mock
-
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
 
 
 class RedactionTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
+    def default_config(self):
+        config = super().default_config()
         config["redaction_retention_period"] = "30d"
-        return self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None, config=config
-        )
+        return config
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ff972daeaa..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
-
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None
-        )
-        return hs
-
     def prepare(self, reactor, clock, hs: TestHomeServer):
 
         # We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver
 ALICE = "@alice:a"
 BOB = "@bob:b"
 BOBBY = "@bobby:a"
+# The localpart isn't 'Bela' on purpose so we can test looking up display names.
+BELA = "@somenickname:a"
 
 
 class UserDirectoryStoreTestCase(unittest.TestCase):
@@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
         )
         yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BELA, "Bela", None)
+        )
+        yield defer.ensureDeferred(
             self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
         )
 
@@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             )
         finally:
             self.hs.config.user_directory_search_all_users = False
+
+    @defer.inlineCallbacks
+    def test_search_user_dir_stop_words(self):
+        """Tests that a user can look up another user by searching for the start if its
+        display name even if that name happens to be a common English word that would
+        usually be ignored in full text searches.
+        """
+        self.hs.config.user_directory_search_all_users = True
+        try:
+            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
+            self.assertFalse(r["limited"])
+            self.assertEqual(1, len(r["results"]))
+            self.assertDictEqual(
+                r["results"][0],
+                {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+            )
+        finally:
+            self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1ce4ea3a01..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
             self.addCleanup,
-            http_client=self.http_client,
+            federation_http_client=self.http_client,
             clock=self.hs_clock,
             reactor=self.reactor,
         )
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        with LoggingContext(request="lying_event"):
+        with LoggingContext():
             failure = self.get_failure(
                 self.handler.on_receive_pdu(
                     "test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index c5ec6396a7..51660b51d5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v2_alpha import register, sync
 
 from tests import unittest
@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
+    def test_as_ignores_mau(self):
+        """Test that application services can still create users when the MAU
+        limit has been reached. This only works when application service
+        user ip tracking is disabled.
+        """
+
+        # Create and sync so that the MAU counts get updated
+        token1 = self.create_user("kermit1")
+        self.do_sync_for_user(token1)
+        token2 = self.create_user("kermit2")
+        self.do_sync_for_user(token2)
+
+        # check we're testing what we think we are: there should be two active users
+        self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
+        # We've created and activated two users, we shouldn't be able to
+        # register new users
+        with self.assertRaises(SynapseError) as cm:
+            self.create_user("kermit3")
+
+        e = cm.exception
+        self.assertEqual(e.code, 403)
+        self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+        # Cheekily add an application service that we use to register a new user
+        # with.
+        as_token = "foobartoken"
+        self.store.services_cache.append(
+            ApplicationService(
+                token=as_token,
+                hostname=self.hs.hostname,
+                id="SomeASID",
+                sender="@as_sender:test",
+                namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
+            )
+        )
+
+        self.create_user("as_kermit4", token=as_token)
+
     def test_allowed_after_a_month_mau(self):
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.reactor.advance(100)
         self.assertEqual(2, self.successResultOf(count))
 
-    def create_user(self, localpart):
+    def create_user(self, localpart, token=None):
         request_data = json.dumps(
             {
                 "username": localpart,
@@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request("POST", "/register", request_data)
+        channel = self.make_request(
+            "POST", "/register", request_data, access_token=token,
+        )
 
         if channel.code != 200:
             raise HttpResponseException(
@@ -213,7 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         return access_token
 
     def do_sync_for_user(self, token):
-        request, channel = self.make_request("GET", "/sync", access_token=token)
+        channel = self.make_request("GET", "/sync", access_token=token)
 
         if channel.code != 200:
             raise HttpResponseException(
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..0c6cbbd921 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import (
 
 from . import unittest
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class PreviewTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_long_summarize(self):
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -56,7 +64,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +77,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø lies in Northern Norway. The municipality has a population of"
             " (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +104,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +130,7 @@ class PreviewTestCase(unittest.TestCase):
         ]
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase):
 
 
 class PreviewUrlTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_simple(self):
         html = """
         <html>
@@ -149,7 +160,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
         html = """
@@ -164,7 +175,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
         html = """
@@ -182,7 +193,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(
+        self.assertEqual(
             og,
             {
                 "og:title": "Foo",
@@ -203,7 +214,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
         html = """
@@ -216,7 +227,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
         html = """
@@ -230,7 +241,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
         html = """
@@ -244,4 +255,38 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+    def test_empty(self):
+        html = ""
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {})
+
+    def test_invalid_encoding(self):
+        """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+        html = """
+        <html>
+        <head><title>Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(
+            html, "http://example.com/test.html", "invalid-encoding"
+        )
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+    def test_invalid_encoding2(self):
+        """A body which doesn't match the sent character encoding."""
+        # Note that this contains an invalid UTF-8 sequence in the title.
+        html = b"""
+        <html>
+        <head><title>\xff\xff Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/test_server.py b/tests/test_server.py
index c387a85f2e..815da18e65 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
         self.reactor = ThreadedMemoryReactorClock()
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+            self.addCleanup,
+            federation_http_client=None,
+            clock=self.hs_clock,
+            reactor=self.reactor,
         )
 
     def test_handler_for_request(self):
@@ -61,11 +64,10 @@ class JsonResourceTests(unittest.TestCase):
             "test_servlet",
         )
 
-        request, channel = make_request(
+        make_request(
             self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
         )
 
-        self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
         self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
 
     def test_callback_direct_exception(self):
@@ -82,7 +84,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"500")
 
@@ -106,7 +108,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"500")
 
@@ -124,7 +126,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -146,9 +148,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(
-            self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
-        )
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
 
         self.assertEqual(channel.result["code"], b"400")
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -170,7 +170,7 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         # The path was registered as GET, but this is a HEAD request.
-        _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
@@ -202,7 +202,7 @@ class OptionsResourceTests(unittest.TestCase):
         )
 
         # render the request and return the channel
-        _, channel = make_request(self.reactor, site, method, path, shorthand=False)
+        channel = make_request(self.reactor, site, method, path, shorthand=False)
         return channel
 
     def test_unknown_options_request(self):
@@ -275,7 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"200")
         body = channel.result["body"]
@@ -293,7 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"301")
         headers = channel.result["headers"]
@@ -314,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"304")
         headers = channel.result["headers"]
@@ -333,7 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 71580b454d..a743cdc3a9 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
     def test_ui_auth(self):
         # Do a UI auth request
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
@@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
 
         self.registration_handler.check_username = Mock(return_value=True)
 
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         # We don't bother checking that the response is correct - we'll leave that to
         # other tests. We just want to make sure we're on the right path.
@@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
                 },
             }
         )
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         # We're interested in getting a response that looks like a successful
         # registration, not so much that the details are exactly what we want.
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..acdeea7a09 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -58,6 +58,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(room.to_string(), "#channel:my.domain")
 
+    def test_validate(self):
+        id_string = "#test:domain,test"
+        self.assertFalse(RoomAlias.is_valid(id_string))
+
 
 class GroupIDTestCase(unittest.TestCase):
     def test_parse(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,13 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar
 
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
 TV = TypeVar("TV")
 
 
@@ -80,3 +87,35 @@ def setup_awaitable_errors() -> Callable[[], None]:
     sys.unraisablehook = unraisablehook  # type: ignore
 
     return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+    # AsyncMock is not available in python3.5, this mimics part of its behaviour
+    async def cb(*args, **kwargs):
+        if raises:
+            raise raises
+        return return_value
+
+    return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+    """A fake twisted.web.IResponse object
+
+    there is a similar class at treq.test.test_response, but it lacks a `phrase`
+    attribute, and didn't support deliverBody until recently.
+    """
+
+    # HTTP response code
+    code = attr.ib(type=int)
+
+    # HTTP response phrase (eg b'OK' for a 200)
+    phrase = attr.ib(type=bytes)
+
+    # body of the response
+    body = attr.ib(type=bytes)
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+    """A generic HTML page parser which extracts useful things from the HTML"""
+
+    def __init__(self):
+        super().__init__()
+
+        # a list of links found in the doc
+        self.links = []  # type: List[str]
+
+        # the values of any hidden <input>s: map from name to value
+        self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+        # the values of any radio buttons: map from name to list of values
+        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+
+    def handle_starttag(
+        self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+    ) -> None:
+        attr_dict = dict(attrs)
+        if tag == "a":
+            href = attr_dict["href"]
+            if href:
+                self.links.append(href)
+        elif tag == "input":
+            input_name = attr_dict.get("name")
+            if attr_dict["type"] == "radio":
+                assert input_name
+                self.radios.setdefault(input_name, []).append(attr_dict["value"])
+            elif attr_dict["type"] == "hidden":
+                assert input_name
+                self.hiddens[input_name] = attr_dict["value"]
+
+    def error(_, message):
+        raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
     handler.setFormatter(formatter)
-    handler.addFilter(LoggingContextFilter(request=""))
+    handler.addFilter(LoggingContextFilter())
     root_logger.addHandler(handler)
 
     log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -46,6 +46,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
         """
         Create a the root resource for the test server.
 
-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servletes against it
+        The default calls `self.create_resource_dict` and builds the resultant dict
+        into a tree.
         """
-        resource = JsonResource(self.hs)
+        root_resource = Resource()
+        create_resource_tree(self.create_resource_dict(), root_resource)
+        return root_resource
 
-        for servlet in self.servlets:
-            servlet(self.hs, resource)
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
 
-        return resource
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        servlet_resource = JsonResource(self.hs)
+        for servlet in self.servlets:
+            servlet(self.hs, servlet_resource)
+        return {
+            "/_matrix/client": servlet_resource,
+            "/_synapse/admin": servlet_resource,
+        }
 
     def default_config(self):
         """
@@ -358,24 +372,6 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
-    # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
-    # when the `request` arg isn't given, so we define an explicit override to
-    # cover that case.
-    @overload
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[SynapseRequest, FakeChannel]:
-        ...
-
-    @overload
     def make_request(
         self,
         method: Union[bytes, str],
@@ -387,21 +383,11 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: str = None,
         content_is_form: bool = False,
         await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
-        ...
-
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+        client_ip: str = "127.0.0.1",
+    ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -423,8 +409,13 @@ class HomeserverTestCase(TestCase):
                  true (the default), will pump the test reactor until the the renderer
                  tells the channel the request is finished.
 
+            custom_headers: (name, value) pairs to add as request headers
+
+            client_ip: The IP to use as the requesting IP. Useful for testing
+                ratelimiting.
+
         Returns:
-            Tuple[synapse.http.site.SynapseRequest, channel]
+            The FakeChannel object which stores the result of the request.
         """
         return make_request(
             self.reactor,
@@ -438,6 +429,8 @@ class HomeserverTestCase(TestCase):
             federation_auth_origin,
             content_is_form,
             await_result,
+            custom_headers,
+            client_ip,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):
@@ -554,7 +547,7 @@ class HomeserverTestCase(TestCase):
         self.hs.config.registration_shared_secret = "shared"
 
         # Create the user
-        request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
+        channel = self.make_request("GET", "/_synapse/admin/v1/register")
         self.assertEqual(channel.code, 200, msg=channel.result)
         nonce = channel.json_body["nonce"]
 
@@ -579,7 +572,7 @@ class HomeserverTestCase(TestCase):
                 "inhibit_login": True,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_synapse/admin/v1/register", body.encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -597,7 +590,7 @@ class HomeserverTestCase(TestCase):
         if device_id:
             body["device_id"] = device_id
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -665,7 +658,7 @@ class HomeserverTestCase(TestCase):
         """
         body = {"type": "m.login.password", "user": username, "password": password}
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 403, channel.result)
@@ -691,13 +684,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """
 
-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -706,11 +715,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
 
-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
 
 
 def override_config(extra_config):
@@ -735,3 +741,29 @@ def override_config(extra_config):
         return func
 
     return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+    """A test decorator which will skip the decorated test unless a condition is set
+
+    For example:
+
+    class MyTestCase(TestCase):
+        @skip_unless(HAS_FOO, "Cannot test without foo")
+        def test_foo(self):
+            ...
+
+    Args:
+        condition: If true, the test will be skipped
+        reason: the reason to give for skipping the test
+    """
+
+    def decorator(f: TV) -> TV:
+        if not condition:
+            f.skip = reason  # type: ignore
+        return f
+
+    return decorator
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..ecd9efc4df 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
 class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get("foo")
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_hit(self):
         cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(("foo",))
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_invalidate_all(self):
         cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(2, "two")
         cache.prefill(3, "three")  # 1 will be evicted
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(1)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
         cache.get(2)
         cache.get(3)
@@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase):
 
         cache.prefill(3, "three")
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(2)
-        except KeyError:
-            failed = True
 
-        self.assertTrue(failed)
+        cache.get(1)
+        cache.get(3)
+
+    def test_eviction_iterable(self):
+        cache = DeferredCache(
+            "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+        )
+
+        cache.prefill(1, ["one", "two"])
+        cache.prefill(2, ["three"])
 
+        # Now access 1 again, thus causing 2 to be least-recently used
+        cache.get(1)
+
+        # Now add an item to the cache, which evicts 2.
+        cache.prefill(3, ["four"])
+        with self.assertRaises(KeyError):
+            cache.get(2)
+
+        # Ensure 1 & 3 are in the cache.
         cache.get(1)
         cache.get(3)
+
+        # Now access 1 again, thus causing 3 to be least-recently used
+        cache.get(1)
+
+        # Now add an item with multiple elements to the cache
+        cache.prefill(4, ["five", "six"])
+
+        # Both 1 and 3 are evicted since there's too many elements.
+        with self.assertRaises(KeyError):
+            cache.get(1)
+        with self.assertRaises(KeyError):
+            cache.get(3)
+
+        # Now add another item to fill the cache again.
+        cache.prefill(5, ["seven"])
+
+        # Now access 4, thus causing 5 to be least-recently used
+        cache.get(4)
+
+        # Add an empty item.
+        cache.prefill(6, [])
+
+        # 5 gets evicted and replaced since an empty element counts as an item.
+        with self.assertRaises(KeyError):
+            cache.get(5)
+        cache.get(4)
+        cache.get(6)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1ef0af8e8f 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
 
 from tests.unittest import TestCase
 
@@ -45,3 +47,60 @@ class ChunkSeqTests(TestCase):
         self.assertEqual(
             list(parts), [],
         )
+
+
+class SortTopologically(TestCase):
+    def test_empty(self):
+        "Test that an empty graph works correctly"
+
+        graph = {}  # type: Dict[int, List[int]]
+        self.assertEqual(list(sorted_topologically([], graph)), [])
+
+    def test_handle_empty_graph(self):
+        "Test that a graph where a node doesn't have an entry is treated as empty"
+
+        graph = {}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_disconnected(self):
+        "Test that a graph with no edges work"
+
+        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_linear(self):
+        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_subset(self):
+        "Test that only sorting a subset of the graph works"
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+    def test_fork(self):
+        "Test that a forked graph works"
+        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+
+        # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+        # always get the same one.
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_duplicates(self):
+        "Test that a graph with duplicate edges work"
+        graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_multiple_paths(self):
+        "Test that a graph with multiple paths between two nodes work"
+        graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
diff --git a/tests/utils.py b/tests/utils.py
index c8d3ffbaba..840b657f82 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
 import time
 import uuid
 import warnings
-from inspect import getcallargs
 from typing import Type
 from urllib import parse as urlparse
 
 from mock import Mock, patch
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
@@ -34,15 +33,12 @@ from synapse.api.room_versions import RoomVersions
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
-from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
 #
@@ -161,6 +157,7 @@ def default_config(name, parse=False):
             "local": {"per_second": 10000, "burst_count": 10000},
             "remote": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
         "saml2_enabled": False,
         "public_baseurl": None,
         "default_identity_server": None,
@@ -342,32 +339,9 @@ def setup_test_homeserver(
 
     hs.get_auth_handler().validate_hash = validate_hash
 
-    fed = kwargs.get("resource_for_federation", None)
-    if fed:
-        register_federation_servlets(hs, fed)
-
     return hs
 
 
-def register_federation_servlets(hs, resource):
-    federation_server.register_servlets(
-        hs,
-        resource=resource,
-        authenticator=federation_server.Authenticator(hs),
-        ratelimiter=FederationRateLimiter(
-            hs.get_clock(), config=hs.config.rc_federation
-        ),
-    )
-
-
-def get_mock_call_args(pattern_func, mock_func):
-    """ Return the arguments the mock function was called with interpreted
-    by the pattern functions argument list.
-    """
-    invoked_args, invoked_kargs = mock_func.call_args
-    return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
 def mock_getRawHeaders(headers=None):
     headers = headers if headers is not None else {}
 
@@ -378,7 +352,7 @@ def mock_getRawHeaders(headers=None):
 
 
 # This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
     def __init__(self, prefix=""):
         self.callbacks = []  # 3-tuple of method/pattern/function
         self.prefix = prefix
@@ -553,86 +527,6 @@ class MockClock:
         return d
 
 
-def _format_call(args, kwargs):
-    return ", ".join(
-        ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
-    )
-
-
-class DeferredMockCallable:
-    """A callable instance that stores a set of pending call expectations and
-    return values for them. It allows a unit test to assert that the given set
-    of function calls are eventually made, by awaiting on them to be called.
-    """
-
-    def __init__(self):
-        self.expectations = []
-        self.calls = []
-
-    def __call__(self, *args, **kwargs):
-        self.calls.append((args, kwargs))
-
-        if not self.expectations:
-            raise ValueError(
-                "%r has no pending calls to handle call(%s)"
-                % (self, _format_call(args, kwargs))
-            )
-
-        for (call, result, d) in self.expectations:
-            if args == call[1] and kwargs == call[2]:
-                d.callback(None)
-                return result
-
-        failure = AssertionError(
-            "Was not expecting call(%s)" % (_format_call(args, kwargs))
-        )
-
-        for _, _, d in self.expectations:
-            try:
-                d.errback(failure)
-            except Exception:
-                pass
-
-        raise failure
-
-    def expect_call_and_return(self, call, result):
-        self.expectations.append((call, result, defer.Deferred()))
-
-    @defer.inlineCallbacks
-    def await_calls(self, timeout=1000):
-        deferred = defer.DeferredList(
-            [d for _, _, d in self.expectations], fireOnOneErrback=True
-        )
-
-        timer = reactor.callLater(
-            timeout / 1000,
-            deferred.errback,
-            AssertionError(
-                "%d pending calls left: %s"
-                % (
-                    len([e for e in self.expectations if not e[2].called]),
-                    [e for e in self.expectations if not e[2].called],
-                )
-            ),
-        )
-
-        yield deferred
-
-        timer.cancel()
-
-        self.calls = []
-
-    def assert_had_no_calls(self):
-        if self.calls:
-            calls = self.calls
-            self.calls = []
-
-            raise AssertionError(
-                "Expected not to received any calls, got:\n"
-                + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
-            )
-
-
 async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
     """