summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_admin.py48
-rw-r--r--tests/rest/admin/test_background_updates.py154
-rw-r--r--tests/rest/admin/test_room.py982
-rw-r--r--tests/rest/admin/test_user.py30
-rw-r--r--tests/rest/client/test_capabilities.py8
-rw-r--r--tests/rest/client/test_directory.py105
-rw-r--r--tests/rest/client/test_login.py73
-rw-r--r--tests/rest/client/test_relations.py173
-rw-r--r--tests/rest/client/test_rooms.py154
-rw-r--r--tests/rest/client/utils.py71
10 files changed, 1635 insertions, 163 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 192073c520..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
                 % server_and_media_id_2
             ),
         )
+
+
+class PurgeHistoryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
+        self.url_status = "/_synapse/admin/v1/purge_history_status/"
+
+    def test_purge_history(self):
+        """
+        Simple test of purge history API.
+        Test only that is is possible to call, get status 200 and purge_id.
+        """
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content={"delete_local_events": True, "purge_up_to_ts": 0},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("purge_id", channel.json_body)
+        purge_id = channel.json_body["purge_id"]
+
+        # get status
+        channel = self.make_request(
+            "GET",
+            self.url_status + purge_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 78c48db552..1786316763 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,8 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+from typing import Collection
+
+from parameterized import parameterized
 
 import synapse.rest.admin
+from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 
@@ -30,6 +35,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
+    @parameterized.expand(
+        [
+            ("GET", "/_synapse/admin/v1/background_updates/enabled"),
+            ("POST", "/_synapse/admin/v1/background_updates/enabled"),
+            ("GET", "/_synapse/admin/v1/background_updates/status"),
+            ("POST", "/_synapse/admin/v1/background_updates/start_job"),
+        ]
+    )
+    def test_requester_is_no_admin(self, method: str, url: str):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        self.register_user("user", "pass", admin=False)
+        other_user_tok = self.login("user", "pass")
+
+        channel = self.make_request(
+            method,
+            url,
+            content={},
+            access_token=other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+        url = "/_synapse/admin/v1/background_updates/start_job"
+
+        # empty content
+        channel = self.make_request(
+            "POST",
+            url,
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+        # job_name invalid
+        channel = self.make_request(
+            "POST",
+            url,
+            content={"job_name": "unknown"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
     def _register_bg_update(self):
         "Adds a bg update but doesn't start it"
 
@@ -60,7 +119,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled, but none should be running.
         self.assertDictEqual(
@@ -82,7 +141,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled, and one should be running.
         self.assertDictEqual(
@@ -114,7 +173,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/enabled",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(channel.json_body, {"enabled": True})
 
         # Disable the BG updates
@@ -124,7 +183,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             content={"enabled": False},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(channel.json_body, {"enabled": False})
 
         # Advance a bit and get the current status, note this will finish the in
@@ -137,7 +196,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(
             channel.json_body,
             {
@@ -162,7 +221,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # There should be no change from the previous /status response.
         self.assertDictEqual(
@@ -188,7 +247,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             content={"enabled": True},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         self.assertDictEqual(channel.json_body, {"enabled": True})
 
@@ -199,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled and making progress.
         self.assertDictEqual(
@@ -216,3 +275,82 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "enabled": True,
             },
         )
+
+    @parameterized.expand(
+        [
+            ("populate_stats_process_rooms", ["populate_stats_process_rooms"]),
+            (
+                "regenerate_directory",
+                [
+                    "populate_user_directory_createtables",
+                    "populate_user_directory_process_rooms",
+                    "populate_user_directory_process_users",
+                    "populate_user_directory_cleanup",
+                ],
+            ),
+        ]
+    )
+    def test_start_backround_job(self, job_name: str, updates: Collection[str]):
+        """
+        Test that background updates add to database and be processed.
+
+        Args:
+            job_name: name of the job to call with API
+            updates: collection of background updates to be started
+        """
+
+        # no background update is waiting
+        self.assertTrue(
+            self.get_success(
+                self.store.db_pool.updates.has_completed_background_updates()
+            )
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/background_updates/start_job",
+            content={"job_name": job_name},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        # test that each background update is waiting now
+        for update in updates:
+            self.assertFalse(
+                self.get_success(
+                    self.store.db_pool.updates.has_completed_background_update(update)
+                )
+            )
+
+        self.wait_for_background_updates()
+
+        # background updates are done
+        self.assertTrue(
+            self.get_success(
+                self.store.db_pool.updates.has_completed_background_updates()
+            )
+        )
+
+    def test_start_backround_job_twice(self):
+        """Test that add a background update twice return an error."""
+
+        # add job to database
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                table="background_updates",
+                values={
+                    "update_name": "populate_stats_process_rooms",
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/background_updates/start_job",
+            content={"job_name": "populate_stats_process_rooms"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46116644ce..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -14,12 +14,16 @@
 
 import json
 import urllib.parse
+from http import HTTPStatus
 from typing import List, Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import Codes
+from synapse.handlers.pagination import PaginationHandler
 from synapse.rest.client import directory, events, login, room
 
 from tests import unittest
@@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            json.dumps({}),
+            {},
             access_token=self.other_user_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_room_does_not_exist(self):
@@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             url,
-            json.dumps({}),
+            {},
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
     def test_room_is_not_valid(self):
@@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             url,
-            json.dumps({}),
+            {},
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "invalidroom is not a legal room ID",
             channel.json_body["error"],
@@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("kicked_users", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "User must be our own: @not:exist.bla",
             channel.json_body["error"],
@@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
     def test_purge_is_not_bool(self):
@@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
     def test_purge_room_and_block(self):
@@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_block_unknown_room(self, purge: bool) -> None:
+        """
+        We can block an unknown room. In this case, the `purge` argument
+        should be ignored.
+        """
+        room_id = "!unknown:test"
+
+        # The room isn't already in the blocked rooms table
+        self._is_blocked(room_id, expect=False)
+
+        # Request the room be blocked.
+        channel = self.make_request(
+            "DELETE",
+            f"/_synapse/admin/v1/rooms/{room_id}",
+            {"block": True, "purge": purge},
+            access_token=self.admin_user_tok,
+        )
+
+        # The room is now blocked.
+        self.assertEqual(
+            HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+        )
+        self._is_blocked(room_id)
+
     def test_shutdown_room_consent(self):
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
@@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             json.dumps({"history_visibility": "world_readable"}),
             access_token=self.other_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Test that room is not purged
         with self.assertRaises(AssertionError):
@@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -418,17 +447,616 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+        url = "events?timeout=0&room_id=" + room_id
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+
+class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.consent.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = f"/_synapse/admin/v2/rooms/{self.room_id}"
+        self.url_status_by_room_id = (
+            f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status"
+        )
+        self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+            ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+        ]
+    )
+    def test_requester_is_no_admin(self, method: str, url: str):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        channel = self.make_request(
+            method,
+            url % self.room_id,
+            content={},
+            access_token=self.other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+            ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+        ]
+    )
+    def test_room_does_not_exist(self, method: str, url: str):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+
+        channel = self.make_request(
+            method,
+            url % "!unknown:test",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+        ]
+    )
+    def test_room_is_not_valid(self, method: str, url: str):
+        """
+        Check that invalid room names, return an error 400.
+        """
+
+        channel = self.make_request(
+            method,
+            url % "invalidroom",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "invalidroom is not a legal room ID",
+            channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the user ID must be from local server but it does not have to exist.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": "@unknown:test"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only local users can create new room to move members.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": "@not:exist.bla"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+            "User must be our own: @not:exist.bla",
+            channel.json_body["error"],
         )
 
+    def test_block_is_not_bool(self):
+        """
+        If parameter `block` is not boolean, return an error
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"block": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_is_not_bool(self):
+        """
+        If parameter `purge` is not boolean, return an error
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"purge": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_delete_expired_status(self):
+        """Test that the task status is removed after expiration."""
+
+        # first task, do not purge, that we can create a second task
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"purge": False},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id1 = channel.json_body["delete_id"]
+
+        # go ahead
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        # second task
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id2 = channel.json_body["delete_id"]
+
+        # get status
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(2, len(channel.json_body["results"]))
+        self.assertEqual("complete", channel.json_body["results"][0]["status"])
+        self.assertEqual("complete", channel.json_body["results"][1]["status"])
+        self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
+        self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+
+        # get status after more than clearing time for first task
+        # second task is not cleared
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+        self.assertEqual("complete", channel.json_body["results"][0]["status"])
+        self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+
+        # get status after more than clearing time for all tasks
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_delete_same_room_twice(self):
+        """Test that the call for delete a room at second time gives an exception."""
+
+        body = {"new_room_user_id": self.admin_user}
+
+        # first call to delete room
+        # and do not wait for finish the task
+        first_channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content=body,
+            access_token=self.admin_user_tok,
+            await_result=False,
+        )
+
+        # second call to delete room
+        second_channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content=body,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(
+            HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
+        )
+        self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
+        self.assertEqual(
+            f"History purge already in progress for {self.room_id}",
+            second_channel.json_body["error"],
+        )
+
+        # get result of first call
+        first_channel.await_result()
+        self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+        self.assertIn("delete_id", first_channel.json_body)
+
+        # check status after finish the task
+        self._test_result(
+            first_channel.json_body["delete_id"],
+            self.other_user,
+            expect_new_room=True,
+        )
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": True, "purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_purge_room_and_not_block(self):
+        """Test to purge a room and do not block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": False, "purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_block_room_and_not_purge(self):
+        """Test to block a room without purging it.
+        Members will not be moved to a new room and will not receive a message.
+        The room will not be purged.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": True, "purge": False},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+        )
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": self.admin_user},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+            user_id=self.other_user,
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+        channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            content={"history_visibility": "world_readable"},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": self.admin_user},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+            user_id=self.other_user,
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(self.room_id, expect_code=403)
+
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+        """Assert that the room is blocked or not"""
+        d = self.store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _has_no_members(self, room_id: str) -> None:
+        """Assert there is now no longer anyone in the room"""
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def _is_member(self, room_id: str, user_id: str) -> None:
+        """Test that user is member of the room"""
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertIn(user_id, users_in_room)
+
+    def _is_purged(self, room_id: str) -> None:
+        """Test that the following tables have been purged of all rows related to the room."""
+        for table in PURGE_TABLES:
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
+
+    def _assert_peek(self, room_id: str, expect_code: int) -> None:
+        """Assert that the admin user can (or cannot) peek into the room."""
+
+        url = f"rooms/{room_id}/initialSync"
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
         url = "events?timeout=0&room_id=" + room_id
         channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+    def _test_result(
+        self,
+        delete_id: str,
+        kicked_user: str,
+        expect_new_room: bool = False,
+    ) -> None:
+        """
+        Test that the result is the expected.
+        Uses both APIs (status by room_id and delete_id)
+
+        Args:
+            delete_id: id of this purge
+            kicked_user: a user_id which is kicked from the room
+            expect_new_room: if we expect that a new room was created
+        """
+
+        # get information by room_id
+        channel_room_id = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+            HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
         )
+        self.assertEqual(1, len(channel_room_id.json_body["results"]))
+        self.assertEqual(
+            delete_id, channel_room_id.json_body["results"][0]["delete_id"]
+        )
+
+        # get information by delete_id
+        channel_delete_id = self.make_request(
+            "GET",
+            self.url_status_by_delete_id + delete_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(
+            HTTPStatus.OK,
+            channel_delete_id.code,
+            msg=channel_delete_id.json_body,
+        )
+
+        # test values that are the same in both responses
+        for content in [
+            channel_room_id.json_body["results"][0],
+            channel_delete_id.json_body,
+        ]:
+            self.assertEqual("complete", content["status"])
+            self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0])
+            self.assertIn("failed_to_kick_users", content["shutdown_room"])
+            self.assertIn("local_aliases", content["shutdown_room"])
+            self.assertNotIn("error", content)
+
+            if expect_new_room:
+                self.assertIsNotNone(content["shutdown_room"]["new_room_id"])
+            else:
+                self.assertIsNone(content["shutdown_room"]["new_room_id"])
 
 
 class RoomTestCase(unittest.HomeserverTestCase):
@@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
 
         # Check request completed successfully
-        self.assertEqual(200, int(channel.code), msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check that response json body contains a "rooms" key
         self.assertTrue(
@@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 url.encode("ascii"),
                 access_token=self.admin_user_tok,
             )
-            self.assertEqual(
-                200, int(channel.result["code"]), msg=channel.result["body"]
-            )
+            self.assertEqual(200, channel.code, msg=channel.json_body)
 
             self.assertTrue("rooms" in channel.json_body)
             for r in channel.json_body["rooms"]:
@@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url.encode("ascii"),
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
     def test_correct_room_attributes(self):
         """Test the correct attributes for a room are returned"""
@@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             {"room_id": room_id},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Set this new alias as the canonical alias for this room
         self.helper.send_state(
@@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url.encode("ascii"),
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check that rooms were returned
         self.assertTrue("rooms" in channel.json_body)
@@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             {"room_id": room_id},
             access_token=admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Set this new alias as the canonical alias for this room
         self.helper.send_state(
@@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.second_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_invalid_parameter(self):
@@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
 
     def test_local_user_does_not_exist(self):
@@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
     def test_remote_user(self):
@@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "This endpoint can only be used with local users",
             channel.json_body["error"],
@@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual("No known servers", channel.json_body["error"])
 
     def test_room_is_not_valid(self):
@@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "invalidroom was not legal room ID or room alias",
             channel.json_body["error"],
@@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.public_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_not_member(self):
@@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_join_private_room_if_member(self):
@@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
         # Join user to room.
@@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_owner(self):
@@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_context_as_non_admin(self):
@@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
                 % (room_id, events[midway]["event_id"]),
                 access_token=tok,
             )
-            self.assertEquals(
-                403, int(channel.result["code"]), msg=channel.result["body"]
-            )
+            self.assertEquals(403, channel.code, msg=channel.json_body)
             self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_context_as_admin(self):
@@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             % (room_id, events[midway]["event_id"]),
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEquals(
             channel.json_body["event"]["event_id"], events[midway]["event_id"]
         )
@@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room (we should have received an
         # invite) and can ban a user.
@@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -1595,13 +2219,241 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         #
         # (Note we assert the error message to ensure that it's not denied for
         # some other reason)
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             channel.json_body["error"],
             "No local admin user in room with power to update power levels.",
         )
 
 
+class BlockRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self._store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = "/_synapse/admin/v1/rooms/%s/block"
+
+    @parameterized.expand([("PUT",), ("GET",)])
+    def test_requester_is_no_admin(self, method: str):
+        """If the user is not a server admin, an error 403 is returned."""
+
+        channel = self.make_request(
+            method,
+            self.url % self.room_id,
+            content={},
+            access_token=self.other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    @parameterized.expand([("PUT",), ("GET",)])
+    def test_room_is_not_valid(self, method: str):
+        """Check that invalid room names, return an error 400."""
+
+        channel = self.make_request(
+            method,
+            self.url % "invalidroom",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "invalidroom is not a legal room ID",
+            channel.json_body["error"],
+        )
+
+    def test_block_is_not_valid(self):
+        """If parameter `block` is not valid, return an error."""
+
+        # `block` is not valid
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            content={"block": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+        # `block` is not set
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+        # no content is send
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+    def test_block_room(self):
+        """Test that block a room is successful."""
+
+        def _request_and_test_block_room(room_id: str) -> None:
+            self._is_blocked(room_id, expect=False)
+            channel = self.make_request(
+                "PUT",
+                self.url % room_id,
+                content={"block": True},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self._is_blocked(room_id, expect=True)
+
+        # known internal room
+        _request_and_test_block_room(self.room_id)
+
+        # unknown internal room
+        _request_and_test_block_room("!unknown:test")
+
+        # unknown remote room
+        _request_and_test_block_room("!unknown:remote")
+
+    def test_block_room_twice(self):
+        """Test that block a room that is already blocked is successful."""
+
+        self._is_blocked(self.room_id, expect=False)
+        for _ in range(2):
+            channel = self.make_request(
+                "PUT",
+                self.url % self.room_id,
+                content={"block": True},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self._is_blocked(self.room_id, expect=True)
+
+    def test_unblock_room(self):
+        """Test that unblock a room is successful."""
+
+        def _request_and_test_unblock_room(room_id: str) -> None:
+            self._block_room(room_id)
+
+            channel = self.make_request(
+                "PUT",
+                self.url % room_id,
+                content={"block": False},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self._is_blocked(room_id, expect=False)
+
+        # known internal room
+        _request_and_test_unblock_room(self.room_id)
+
+        # unknown internal room
+        _request_and_test_unblock_room("!unknown:test")
+
+        # unknown remote room
+        _request_and_test_unblock_room("!unknown:remote")
+
+    def test_unblock_room_twice(self):
+        """Test that unblock a room that is not blocked is successful."""
+
+        self._block_room(self.room_id)
+        for _ in range(2):
+            channel = self.make_request(
+                "PUT",
+                self.url % self.room_id,
+                content={"block": False},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self._is_blocked(self.room_id, expect=False)
+
+    def test_get_blocked_room(self):
+        """Test get status of a blocked room"""
+
+        def _request_blocked_room(room_id: str) -> None:
+            self._block_room(room_id)
+
+            channel = self.make_request(
+                "GET",
+                self.url % room_id,
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self.assertEqual(self.other_user, channel.json_body["user_id"])
+
+        # known internal room
+        _request_blocked_room(self.room_id)
+
+        # unknown internal room
+        _request_blocked_room("!unknown:test")
+
+        # unknown remote room
+        _request_blocked_room("!unknown:remote")
+
+    def test_get_unblocked_room(self):
+        """Test get status of a unblocked room"""
+
+        def _request_unblocked_room(room_id: str) -> None:
+            self._is_blocked(room_id, expect=False)
+
+            channel = self.make_request(
+                "GET",
+                self.url % room_id,
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self.assertNotIn("user_id", channel.json_body)
+
+        # known internal room
+        _request_unblocked_room(self.room_id)
+
+        # unknown internal room
+        _request_unblocked_room("!unknown:test")
+
+        # unknown remote room
+        _request_unblocked_room("!unknown:remote")
+
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+        """Assert that the room is blocked or not"""
+        d = self._store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _block_room(self, room_id: str) -> None:
+        """Block a room in database"""
+        self.get_success(self._store.block_room(room_id, self.other_user))
+        self._is_blocked(room_id, expect=True)
+
+
 PURGE_TABLES = [
     "current_state_events",
     "event_backward_extremities",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 25e8d6cf27..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.admin_user, device_id=None, valid_until_ms=None
             )
         )
 
         self.other_user = self.register_user("user", "pass", displayname="User")
         self.other_user_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.other_user, device_id=None, valid_until_ms=None
             )
         )
@@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-    def test_no_auth(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_no_auth(self, method: str):
         """
         Try to get information of an user without authentication.
         """
-        channel = self.make_request("POST", self.url)
+        channel = self.make_request(method, self.url)
         self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_not_admin(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_requester_is_not_admin(self, method: str):
         """
         If the user is not a server admin, an error is returned.
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        channel = self.make_request(method, self.url, access_token=other_user_token)
         self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_is_not_local(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_user_is_not_local(self, method: str):
         """
         Tests that shadow-banning for a user that is not a local returns a 400
         """
         url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
 
-        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        channel = self.make_request(method, url, access_token=self.admin_user_tok)
         self.assertEqual(400, channel.code, msg=channel.json_body)
 
     def test_success(self):
@@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
         self.assertTrue(result.shadow_banned)
 
+        # Un-shadow-ban the user.
+        channel = self.make_request(
+            "DELETE", self.url, access_token=self.admin_user_tok
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is no longer shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
 
 class RateLimitTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index b9e3602552..249808b031 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     def test_get_does_include_msc3244_fields_when_enabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..aca03afd0e 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,12 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import directory, login, room
+from synapse.server import HomeServer
 from synapse.types import RoomAlias
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["require_membership_for_aliases"] = True
 
@@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        """Create two local users and access tokens for them.
+        One of them creates a room."""
         self.room_owner = self.register_user("room_owner", "test")
         self.room_owner_tok = self.login("room_owner", "test")
 
@@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "test")
         self.user_tok = self.login("user", "test")
 
-    def test_state_event_not_in_room(self):
+    def test_state_event_not_in_room(self) -> None:
         self.ensure_user_left_room()
-        self.set_alias_via_state_event(403)
+        self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
 
-    def test_directory_endpoint_not_in_room(self):
+    def test_directory_endpoint_not_in_room(self) -> None:
         self.ensure_user_left_room()
-        self.set_alias_via_directory(403)
+        self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
 
-    def test_state_event_in_room_too_long(self):
+    def test_state_event_in_room_too_long(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(400, alias_length=256)
+        self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
 
-    def test_directory_in_room_too_long(self):
+    def test_directory_in_room_too_long(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_directory(400, alias_length=256)
+        self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
 
     @override_config({"default_room_version": 5})
-    def test_state_event_user_in_v5_room(self):
+    def test_state_event_user_in_v5_room(self) -> None:
         """Test that a regular user can add alias events before room v6"""
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(200)
+        self.set_alias_via_state_event(HTTPStatus.OK)
 
     @override_config({"default_room_version": 6})
-    def test_state_event_v6_room(self):
+    def test_state_event_v6_room(self) -> None:
         """Test that a regular user can *not* add alias events from room v6"""
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(403)
+        self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
 
-    def test_directory_in_room(self):
+    def test_directory_in_room(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_directory(200)
+        self.set_alias_via_directory(HTTPStatus.OK)
 
-    def test_room_creation_too_long(self):
+    def test_room_creation_too_long(self) -> None:
         url = "/_matrix/client/r0/createRoom"
 
         # We use deliberately a localpart under the length threshold so
@@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
 
-    def test_room_creation(self):
+    def test_room_creation(self) -> None:
         url = "/_matrix/client/r0/createRoom"
 
         # Check with an alias of allowed length. There should already be
@@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    def test_deleting_alias_via_directory(self) -> None:
+        # Add an alias for the room. We must be joined to do so.
+        self.ensure_user_joined_room()
+        alias = self.set_alias_via_directory(HTTPStatus.OK)
+
+        # Then try to remove the alias
+        channel = self.make_request(
+            "DELETE",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    def test_deleting_nonexistant_alias(self) -> None:
+        # Check that no alias exists
+        alias = "#potato:test"
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertIn("error", channel.json_body, channel.json_body)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
+
+        # Then try to remove the alias
+        channel = self.make_request(
+            "DELETE",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertIn("error", channel.json_body, channel.json_body)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
 
-    def set_alias_via_state_event(self, expected_code, alias_length=5):
+    def set_alias_via_state_event(
+        self, expected_code: HTTPStatus, alias_length: int = 5
+    ) -> None:
         url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
             self.room_id,
             self.hs.hostname,
@@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, expected_code, channel.result)
 
-    def set_alias_via_directory(self, expected_code, alias_length=5):
-        url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+    def set_alias_via_directory(
+        self, expected_code: HTTPStatus, alias_length: int = 5
+    ) -> str:
+        alias = self.random_alias(alias_length)
+        url = "/_matrix/client/r0/directory/room/%s" % alias
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
@@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
+        return alias
 
-    def random_alias(self, length):
+    def random_alias(self, length: int) -> str:
         return RoomAlias(random_string(length), self.hs.hostname).to_string()
 
-    def ensure_user_left_room(self):
+    def ensure_user_left_room(self) -> None:
         self.ensure_membership("leave")
 
-    def ensure_user_joined_room(self):
+    def ensure_user_joined_room(self) -> None:
         self.ensure_membership("join")
 
-    def ensure_membership(self, membership):
+    def ensure_membership(self, membership: str) -> None:
         try:
             if membership == "leave":
                 self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a63f04bd41..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fΓΆ&=o"')]
 
 # (possibly experimental) login flows we expect to appear in the list after the normal
 # ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+    {"type": "m.login.application_service"},
+    {"type": "uk.half-shot.msc2778.login.application_service"},
+]
 
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     jwt_secret = "secret"
     jwt_algorithm = "HS256"
+    base_config = {
+        "enabled": True,
+        "secret": jwt_secret,
+        "algorithm": jwt_algorithm,
+    }
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_secret
-        self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+
+        # If jwt_config has been defined (eg via @override_config), don't replace it.
+        if config.get("jwt_config") is None:
+            config["jwt_config"] = self.base_config
+
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "issuer": "test-issuer",
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
     def test_login_iss(self):
         """Test validating the issuer claim."""
         # A valid issuer.
@@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "audiences": ["test-audience"],
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
     def test_login_aud(self):
         """Test validating the audience claim."""
         # A valid audience.
@@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
 
+    def test_login_default_sub(self):
+        """Test reading user ID from the default subject claim."""
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+    def test_login_custom_sub(self):
+        """Test reading user ID from a custom subject claim."""
+        channel = self.jwt_login({"username": "frog"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
     def test_login_no_token(self):
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         ]
     )
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_pubkey
-        self.hs.config.jwt.jwt_algorithm = "RS256"
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+        config["jwt_config"] = {
+            "enabled": True,
+            "secret": self.jwt_pubkey,
+            "algorithm": "RS256",
+        }
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
 # Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
 
@@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_invalid_event(self):
+        """Test that we deny relations on non-existant events"""
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
+        # Unless that event is referenced from another event!
+        self.get_success(
+            self.hs.get_datastore().db_pool.simple_insert(
+                table="event_relations",
+                values={
+                    "event_id": "bar",
+                    "relates_to_id": "foo",
+                    "relation_type": RelationTypes.THREAD,
+                },
+                desc="test_deny_invalid_event",
+            )
+        )
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+    def test_deny_invalid_room(self):
+        """Test that we deny relations on non-existant events"""
+        # Create another room and send a message in it.
+        room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+        parent_id = res["event_id"]
+
+        # Attempt to send an annotation to that event.
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_deny_double_react(self):
         """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_forked_thread(self):
+        """It is invalid to start a thread off a thread."""
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=self.parent_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        parent_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=parent_id,
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_basic_paginate_relations(self):
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEquals(channel.json_body["chunk"], [])
 
+    def test_unknown_relations(self):
+        """Unknown relations should be accepted."""
+        channel = self._send_relation("m.relation.test", "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEquals(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # When bundling the unknown relation is not included.
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+        # But unknown relations can be directly queried.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
     def _send_relation(
         self,
         relation_type: str,
@@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         access_token = self.login(localpart, "abc123")
 
         return user_id, access_token
+
+    def test_background_update(self):
+        """Test the event_arbitrary_relations background update."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="πŸ‘")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_good = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_bad = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_event_id = channel.json_body["event_id"]
+
+        # Clean-up the table as if the inserts did not happen during event creation.
+        self.get_success(
+            self.store.db_pool.simple_delete_many(
+                table="event_relations",
+                column="event_id",
+                iterable=(annotation_event_id_bad, thread_event_id),
+                keyvalues={},
+                desc="RelationsTestCase.test_background_update",
+            )
+        )
+
+        # Only the "good" annotation should be found.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good],
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+        self.wait_for_background_updates()
+
+        # The "good" annotation and the thread should be found, but not the "bad"
+        # annotation.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertCountEqual(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good, thread_event_id],
+        )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 376853fd65..10a4a4dc5e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -25,7 +25,12 @@ from urllib import parse as urlparse
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+)
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
@@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         return event_id
 
 
+class RelationsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["experimental_features"] = {"msc3440_enabled": True}
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("test", "test")
+        self.tok = self.login("test", "test")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+        self.helper.join(
+            room=self.room_id, user=self.second_user_id, tok=self.second_tok
+        )
+
+        self.third_user_id = self.register_user("third", "test")
+        self.third_tok = self.login("third", "test")
+        self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+        # An initial event with a relation from second user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 1"},
+            tok=self.tok,
+        )
+        self.event_id_1 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": self.event_id_1,
+                    "key": "πŸ‘",
+                }
+            },
+            tok=self.second_tok,
+        )
+
+        # Another event with a relation from third user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 2"},
+            tok=self.tok,
+        )
+        self.event_id_2 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.REFERENCE,
+                    "event_id": self.event_id_2,
+                }
+            },
+            tok=self.third_tok,
+        )
+
+        # An event with no relations.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "No relations"},
+            tok=self.tok,
+        )
+
+    def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+        """Make a request to /messages with a filter, returns the chunk of events."""
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        return channel.json_body["chunk"]
+
+    def test_filter_relation_senders(self):
+        # Messages which second user reacted to.
+        filter = {"io.element.relation_senders": [self.second_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+        # Messages which third user reacted to.
+        filter = {"io.element.relation_senders": [self.third_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+        # Messages which either user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_type(self):
+        # Messages which have annotations.
+        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+        # Messages which have references.
+        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+        # Messages which have either annotations or references.
+        filter = {
+            "io.element.relation_types": [
+                RelationTypes.ANNOTATION,
+                RelationTypes.REFERENCE,
+            ]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_senders_and_type(self):
+        # Messages which second user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id],
+            "io.element.relation_types": [RelationTypes.ANNOTATION],
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+
 class ContextTestCase(unittest.HomeserverTestCase):
 
     servlets = [
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,10 +19,21 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    Dict,
+    Iterable,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    overload,
+)
 from unittest.mock import patch
 
 import attr
+from typing_extensions import Literal
 
 from twisted.web.resource import Resource
 from twisted.web.server import Site
@@ -45,6 +56,32 @@ class RestHelper:
     site = attr.ib(type=Site)
     auth_user_id = attr.ib()
 
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: Literal[200] = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> str:
+        ...
+
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: int = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> Optional[str]:
+        ...
+
     def create_room_as(
         self,
         room_creator: Optional[str] = None,
@@ -53,10 +90,8 @@ class RestHelper:
         tok: Optional[str] = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
-    ) -> str:
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+    ) -> Optional[str]:
         """
         Create a room.
 
@@ -99,6 +134,8 @@ class RestHelper:
 
         if expect_code == 200:
             return channel.json_body["room_id"]
+        else:
+            return None
 
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(
@@ -168,7 +205,7 @@ class RestHelper:
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
         expect_code: int = 200,
-        expect_errcode: str = None,
+        expect_errcode: Optional[str] = None,
     ) -> None:
         """
         Send a membership state event into a room.
@@ -227,9 +264,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if body is None:
             body = "body_text_here"
@@ -254,9 +289,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -418,7 +451,7 @@ class RestHelper:
             path,
             content=image_data,
             access_token=tok,
-            custom_headers=[(b"Content-Length", str(image_length))],
+            custom_headers=[("Content-Length", str(image_length))],
         )
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
@@ -503,7 +536,7 @@ class RestHelper:
             went.
         """
 
-        cookies = {}
+        cookies: Dict[str, str] = {}
 
         # if we're doing a ui auth, hit the ui auth redirect endpoint
         if ui_auth_session_id:
@@ -625,7 +658,13 @@ class RestHelper:
 
         # hit the redirect url again with the right Host header, which should now issue
         # a cookie and redirect to the SSO provider.
-        location = channel.headers.getRawHeaders("Location")[0]
+        def get_location(channel: FakeChannel) -> str:
+            location_values = channel.headers.getRawHeaders("Location")
+            # Keep mypy happy by asserting that location_values is nonempty
+            assert location_values
+            return location_values[0]
+
+        location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
             self.hs.get_reactor(),
@@ -639,7 +678,7 @@ class RestHelper:
 
         assert channel.code == 302
         channel.extract_cookies(cookies)
-        return channel.headers.getRawHeaders("Location")[0]
+        return get_location(channel)
 
     def initiate_sso_ui_auth(
         self, ui_auth_session_id: str, cookies: MutableMapping[str, str]