summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-03-28 13:54:02 +0100
committerBrendan Abolivier <babolivier@matrix.org>2022-03-28 13:54:02 +0100
commit25507bffc67c40e83cbcd4a79fdfee3667855a7c (patch)
tree5620b2a06a5a9894ac875ddcf3b232db45cae48d /tests/rest
parentMerge branch 'develop' of github.com:matrix-org/synapse into babolivier/sign_... (diff)
parentAdd restrictions by default to open registration in Synapse (#12091) (diff)
downloadsynapse-babolivier/sign_json_module.tar.xz
Merge branch 'develop' into babolivier/sign_json_module github/babolivier/sign_json_module babolivier/sign_json_module
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_background_updates.py21
-rw-r--r--tests/rest/admin/test_user.py19
-rw-r--r--tests/rest/client/test_account.py58
-rw-r--r--tests/rest/client/test_mutual_rooms.py (renamed from tests/rest/client/test_shared_rooms.py)30
-rw-r--r--tests/rest/client/test_relations.py1536
-rw-r--r--tests/rest/client/test_retention.py29
-rw-r--r--tests/rest/client/test_rooms.py18
-rw-r--r--tests/rest/client/test_third_party_rules.py121
-rw-r--r--tests/rest/client/test_transactions.py19
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py44
-rw-r--r--tests/rest/media/v1/test_base.py4
-rw-r--r--tests/rest/media/v1/test_filepath.py48
-rw-r--r--tests/rest/media/v1/test_html_preview.py102
-rw-r--r--tests/rest/media/v1/test_media_storage.py110
-rw-r--r--tests/rest/media/v1/test_oembed.py10
-rw-r--r--tests/rest/media/v1/test_url_preview.py83
-rw-r--r--tests/rest/test_health.py8
-rw-r--r--tests/rest/test_well_known.py20
18 files changed, 1341 insertions, 939 deletions
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index fb36aa9940..6cf56b1e35 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
+        self.updater = BackgroundUpdater(hs, self.store.db_pool)
 
     @parameterized.expand(
         [
@@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         """Test the status API works with a background update."""
 
         # Create a new background update
-
         self._register_bg_update()
 
         self.store.db_pool.updates.start_doing_background_updates()
+
         self.reactor.pump([1.0, 1.0, 1.0])
 
         channel = self.make_request(
@@ -155,10 +156,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -210,10 +211,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -239,10 +240,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -278,11 +279,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.05263157894736842,
                         "total_duration_ms": 2000.0,
-                        "total_item_count": (
-                            2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
-                        ),
+                        "total_item_count": (110),
                     }
                 },
                 "enabled": True,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index a60ea0a563..bef911d5df 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         self._is_erased("@user:test", True)
 
+    @override_config({"max_avatar_size": 1234})
+    def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
+        """Check we can erase a user whose avatar is the empty string.
+
+        Reproduces #12257.
+        """
+        # Patch `self.other_user` to have an empty string as their avatar.
+        self.get_success(self.store.set_profile_avatar_url("user", ""))
+
+        # Check we can still erase them.
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"erase": True},
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self._is_erased("@user:test", True)
+
     def test_deactivate_user_erase_false(self) -> None:
         """
         Test deactivating a user and set `erase` to `false`
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[users[2]],
         )
 
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_no_account_validity(self) -> None:
+        """Tests that if we decide to include account validity in the response but no
+        account validity 'is_user_expired' callback is provided, we default to marking all
+        users as not expired.
+        """
+        user = self.register_user("someuser", "password")
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_account_validity_expired(self) -> None:
+        """Test that if we decide to include account validity in the response and the user
+        is expired, we return the correct info.
+        """
+        user = self.register_user("someuser", "password")
+
+        async def is_expired(user_id: str) -> bool:
+            # We can't blindly say everyone is expired, otherwise the request to get the
+            # account status will fail.
+            return UserID.from_string(user_id).localpart == "someuser"
+
+        self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+            is_expired
+        )
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": True,
+                },
+            },
+            expected_failures=[],
+        )
+
     def _test_status(
         self,
         users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -22,16 +22,16 @@ from tests import unittest
 from tests.server import FakeChannel
 
 
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
     """
-    Tests the UserSharedRoomsServlet.
+    Tests the UserMutualRoomsServlet.
     """
 
     servlets = [
         login.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
-        shared_rooms.register_servlets,
+        mutual_rooms.register_servlets,
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.handler = hs.get_user_directory_handler()
 
-    def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+    def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
         return self.make_request(
             "GET",
-            "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+            "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
             % other_user,
             access_token=token,
         )
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         A room should show up in the shared list of rooms between two users
         if it is public.
         """
-        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+        self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
 
     def test_shared_room_list_private(self) -> None:
         """
         A room should show up in the shared list of rooms between two users
         if it is private.
         """
-        self._check_shared_rooms_with(
+        self._check_mutual_rooms_with(
             room_one_is_public=False, room_two_is_public=False
         )
 
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         The shared room list between two users should contain both public and private
         rooms.
         """
-        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+        self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
 
-    def _check_shared_rooms_with(
+    def _check_mutual_rooms_with(
         self, room_one_is_public: bool, room_two_is_public: bool
     ) -> None:
         """Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         # Check shared rooms from user1's perspective.
         # We should see the one room in common
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 1)
         self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room_id_two, user=u2, tok=u2_token)
 
         # Check shared rooms again. We should now see both rooms.
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 2)
         for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Assert user directory is not empty
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 1)
         self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.leave(room, user=u1, tok=u1_token)
 
         # Check user1's view of shared rooms with user2
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 0)
 
         # Check user2's view of shared rooms with user1
-        channel = self._get_shared_rooms(u2_token, u1)
+        channel = self._get_mutual_rooms(u2_token, u1)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 709f851a38..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,17 +15,16 @@
 
 import itertools
 import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
 from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
         content: Optional[dict] = None,
         access_token: Optional[str] = None,
         parent_id: Optional[str] = None,
+        expected_response_code: int = 200,
     ) -> FakeChannel:
         """Helper function to send a relation pointing at `self.parent_id`
 
@@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
             content,
             access_token=access_token,
         )
+        self.assertEqual(expected_response_code, channel.code, channel.json_body)
         return channel
 
+    def _get_related_events(self) -> List[str]:
+        """
+        Requests /relations on the parent ID and returns a list of event IDs.
+        """
+        # Request the relations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+    def _get_bundled_aggregations(self) -> JsonDict:
+        """
+        Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+        """
+        # Fetch the bundled aggregations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return channel.json_body["unsigned"].get("m.relations", {})
+
+    def _get_aggregations(self) -> List[JsonDict]:
+        """Request /aggregations on the parent ID and includes the returned chunk."""
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        return channel.json_body["chunk"]
+
+    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+        """
+        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+        """
+        for event in events:
+            if event["event_id"] == self.parent_id:
+                return event
+
+        raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
 
 class RelationsTestCase(BaseRelationsTestCase):
     def test_send_relation(self) -> None:
         """Tests that sending a relation works."""
-
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
-
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.ANNOTATION,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
         # Unless that event is referenced from another event!
         self.get_success(
@@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
                 desc="test_deny_invalid_event",
             )
         )
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
     def test_deny_invalid_room(self) -> None:
         """Test that we deny relations on non-existant events"""
@@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
         parent_id = res["event_id"]
 
         # Attempt to send an annotation to that event.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            parent_id=parent_id,
+            key="A",
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
     def test_deny_double_react(self) -> None:
         """Test that we deny relations on membership events"""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(400, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+        )
 
     def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
@@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=self.parent_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         parent_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=parent_id,
-        )
-        self.assertEqual(400, channel.code, channel.json_body)
-
-    def test_basic_paginate_relations(self) -> None:
-        """Tests that calling pagination API correctly the latest relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        first_annotation_id = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-        second_annotation_id = channel.json_body["event_id"]
-
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the latest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": second_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-        # We also expect to get the original event (the id of which is self.parent_id)
-        self.assertEqual(
-            channel.json_body["original_event"]["event_id"], self.parent_id
-        )
-
-        # Make sure next_batch has something in it that looks like it could be a
-        # valid token.
-        self.assertIsInstance(
-            channel.json_body.get("next_batch"), str, channel.json_body
-        )
-
-        # Request the relations again, but with a different direction.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the earliest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": first_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-    def _stream_token_to_relation_token(self, token: str) -> str:
-        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
-        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
-        return self.get_success(
-            RelationPaginationToken(
-                topological=room_key.topological, stream=room_key.stream
-            ).to_string(self.store)
-        )
-
-    def test_repeated_paginate_relations(self) -> None:
-        """Test that if we paginate using a limit and tokens then we get the
-        expected events.
-        """
-
-        expected_event_ids = []
-        for idx in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-        # Reset and try again, but convert the tokens to the legacy format.
-        prev_token = ""
-        found_event_ids = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-    def test_pagination_from_sync_and_messages(self) -> None:
-        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-        # Send an event after the relation events.
-        self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
-        # Request /sync, limiting it such that only the latest event is returned
-        # (and not the relation).
-        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
-        channel = self.make_request(
-            "GET", f"/sync?filter={filter}", access_token=self.user_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        sync_prev_batch = room_timeline["prev_batch"]
-        self.assertIsNotNone(sync_prev_batch)
-        # Ensure the relation event is not in the batch returned from /sync.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
-        )
-
-        # Request /messages, limiting it such that only the latest event is
-        # returned (and not the relation).
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b&limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        messages_end = channel.json_body["end"]
-        self.assertIsNotNone(messages_end)
-        # Ensure the relation event is not in the chunk returned from /messages.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            expected_response_code=400,
         )
 
-        # Request /relations with the pagination tokens received from both the
-        # /sync and /messages responses above, in turn.
-        #
-        # This is a tiny bit silly since the client wouldn't know the parent ID
-        # from the requests above; consider the parent ID to be known from a
-        # previous /sync.
-        for from_token in (sync_prev_batch, messages_end):
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            # The relation should be in the returned chunk.
-            self.assertIn(
-                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
-            )
-
-    def test_aggregation_pagination_groups(self) -> None:
-        """Test that we can paginate annotation groups correctly."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
-        for key in itertools.chain.from_iterable(
-            itertools.repeat(key, num) for key, num in sent_groups.items()
-        ):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key=key,
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            idx += 1
-            idx %= len(access_tokens)
-
-        prev_token: Optional[str] = None
-        found_groups: Dict[str, int] = {}
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            for groups in channel.json_body["chunk"]:
-                # We only expect reactions
-                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
-                # We should only see each key once
-                self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
-                found_groups[groups["key"]] = groups["count"]
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        self.assertEqual(sent_groups, found_groups)
-
-    def test_aggregation_pagination_within_group(self) -> None:
-        """Test that we can paginate within an annotation group."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        expected_event_ids = []
-        for _ in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key="👍",
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-            idx += 1
-
-        # Also send a different type of reaction so that we test we don't see it
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        encoded_key = urllib.parse.quote_plus("👍".encode())
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-        # Reset and try again, but convert the tokens to the legacy format.
-        prev_token = ""
-        found_event_ids = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
     def test_aggregation(self) -> None:
         """Test that annotations get correctly aggregated."""
 
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
 
         channel = self.make_request(
             "GET",
@@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(400, channel.code, channel.json_body)
 
-    @unittest.override_config(
-        {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
-    )
-    def test_bundled_aggregations(self) -> None:
-        """
-        Test that annotations, references, and threads get correctly bundled.
-
-        Note that this doesn't test against /relations since only thread relations
-        get bundled via that API. See test_aggregation_get_event_for_thread.
-
-        See test_edit for a similar test for edits.
-        """
-        # Setup by sending a variety of relations.
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_1 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_2 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_2 = channel.json_body["event_id"]
-
-        def assert_bundle(event_json: JsonDict) -> None:
-            """Assert the expected values of the bundled aggregations."""
-            relations_dict = event_json["unsigned"].get("m.relations")
-
-            # Ensure the fields are as expected.
-            self.assertCountEqual(
-                relations_dict.keys(),
-                (
-                    RelationTypes.ANNOTATION,
-                    RelationTypes.REFERENCE,
-                    RelationTypes.THREAD,
-                ),
-            )
-
-            # Check the values of each field.
-            self.assertEqual(
-                {
-                    "chunk": [
-                        {"type": "m.reaction", "key": "a", "count": 2},
-                        {"type": "m.reaction", "key": "b", "count": 1},
-                    ]
-                },
-                relations_dict[RelationTypes.ANNOTATION],
-            )
-
-            self.assertEqual(
-                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
-                relations_dict[RelationTypes.REFERENCE],
-            )
-
-            self.assertEqual(
-                2,
-                relations_dict[RelationTypes.THREAD].get("count"),
-            )
-            self.assertTrue(
-                relations_dict[RelationTypes.THREAD].get("current_user_participated")
-            )
-            # The latest thread event has some fields that don't matter.
-            self.assert_dict(
-                {
-                    "content": {
-                        "m.relates_to": {
-                            "event_id": self.parent_id,
-                            "rel_type": RelationTypes.THREAD,
-                        }
-                    },
-                    "event_id": thread_2,
-                    "room_id": self.room,
-                    "sender": self.user_id,
-                    "type": "m.room.test",
-                    "user_id": self.user_id,
-                },
-                relations_dict[RelationTypes.THREAD].get("latest_event"),
-            )
-
-        # Request the event directly.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body)
-
-        # Request the room messages.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
-        # Request the room context.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body["event"])
-
-        # Request sync.
-        channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        self.assertTrue(room_timeline["limited"])
-        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
-        # Request search.
-        channel = self.make_request(
-            "POST",
-            "/search",
-            # Search term matches the parent message.
-            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        chunk = [
-            result["result"]
-            for result in channel.json_body["search_categories"]["room_events"][
-                "results"
-            ]
-        ]
-        assert_bundle(self._find_event_in_chunk(chunk))
-
-    def test_aggregation_get_event_for_annotation(self) -> None:
-        """Test that annotations do not get bundled aggregations included
-        when directly requested.
-        """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{annotation_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
-    def test_aggregation_get_event_for_thread(self) -> None:
-        """Test that threads get bundled aggregations included when directly requested."""
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{thread_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
-        # It should also be included when the entire thread is requested.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
-
-        thread_message = channel.json_body["chunk"][0]
-        self.assertEqual(
-            thread_message["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
@@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         def assert_bundle(event_json: JsonDict) -> None:
@@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         shouldn't be allowed, are correctly handled.
         """
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         channel = self._send_relation(
@@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message.WRONG_TYPE",
             content={
@@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
@@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         reply = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=reply,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_edit_thread(self) -> None:
         """Test that editing a thread works."""
 
@@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A threaded reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         threaded_event_id = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=threaded_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Fetch the thread root, to get the bundled aggregation for the thread.
         channel = self.make_request(
@@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": new_body,
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         edit_event_id = channel.json_body["event_id"]
 
         # Edit the edit event.
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             },
             parent_id=edit_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Request the original event.
         channel = self.make_request(
@@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
     def test_unknown_relations(self) -> None:
         """Unknown relations should be accepted."""
         channel = self._send_relation("m.relation.test", "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
-        """
-        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
-        """
-        for event in events:
-            if event["event_id"] == self.parent_id:
-                return event
-
-        raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
     def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_good = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_bad = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         thread_event_id = channel.json_body["event_id"]
 
         # Clean-up the table as if the inserts did not happen during event creation.
@@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
 
 
+class RelationPaginationTestCase(BaseRelationsTestCase):
+    def test_basic_paginate_relations(self) -> None:
+        """Tests that calling pagination API correctly the latest relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        first_annotation_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+        second_annotation_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the latest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": second_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEqual(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        self.assertIsInstance(
+            channel.json_body.get("next_batch"), str, channel.json_body
+        )
+
+        # Request the relations again, but with a different direction.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations"
+            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the earliest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": first_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+    def test_repeated_paginate_relations(self) -> None:
+        """Test that if we paginate using a limit and tokens then we get the
+        expected events.
+        """
+
+        expected_event_ids = []
+        for idx in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+    def test_pagination_from_sync_and_messages(self) -> None:
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
+    def test_aggregation_pagination_groups(self) -> None:
+        """Test that we can paginate annotation groups correctly."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+        for key in itertools.chain.from_iterable(
+            itertools.repeat(key, num) for key, num in sent_groups.items()
+        ):
+            self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key=key,
+                access_token=access_tokens[idx],
+            )
+
+            idx += 1
+            idx %= len(access_tokens)
+
+        prev_token: Optional[str] = None
+        found_groups: Dict[str, int] = {}
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            for groups in channel.json_body["chunk"]:
+                # We only expect reactions
+                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+                # We should only see each key once
+                self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+                found_groups[groups["key"]] = groups["count"]
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        self.assertEqual(sent_groups, found_groups)
+
+    def test_aggregation_pagination_within_group(self) -> None:
+        """Test that we can paginate within an annotation group."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        expected_event_ids = []
+        for _ in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key="👍",
+                access_token=access_tokens[idx],
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+            idx += 1
+
+        # Also send a different type of reaction so that we test we don't see it
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        encoded_key = urllib.parse.quote_plus("👍".encode())
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+    """
+    See RelationsTestCase.test_edit for a similar test for edits.
+
+    Note that this doesn't test against /relations since only thread relations
+    get bundled via that API. See test_aggregation_get_event_for_thread.
+    """
+
+    def _test_bundled_aggregations(
+        self,
+        relation_type: str,
+        assertion_callable: Callable[[JsonDict], None],
+        expected_db_txn_for_event: int,
+    ) -> None:
+        """
+        Makes requests to various endpoints which should include bundled aggregations
+        and then calls an assertion function on the bundled aggregations.
+
+        Args:
+            relation_type: The field to search for in the `m.relations` field in unsigned.
+            assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+                for relation-specific assertions.
+            expected_db_txn_for_event: The number of database transactions which
+                are expected for a call to /event/.
+        """
+
+        def assert_bundle(event_json: JsonDict) -> None:
+            """Assert the expected values of the bundled aggregations."""
+            relations_dict = event_json["unsigned"].get("m.relations")
+
+            # Ensure the fields are as expected.
+            self.assertCountEqual(relations_dict.keys(), (relation_type,))
+            assertion_callable(relations_dict[relation_type])
+
+        # Request the event directly.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body)
+        assert channel.resource_usage is not None
+        self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+        # Request the room messages.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+        # Request the room context.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/context/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body["event"])
+
+        # Request sync.
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_annotation(self) -> None:
+        """
+        Test that annotations get correctly bundled.
+        """
+        # Setup by sending a variety of relations.
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+        )
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {
+                    "chunk": [
+                        {"type": "m.reaction", "key": "a", "count": 2},
+                        {"type": "m.reaction", "key": "b", "count": 1},
+                    ]
+                },
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_reference(self) -> None:
+        """
+        Test that references get correctly bundled.
+        """
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_thread(self) -> None:
+        """
+        Test that threads get correctly bundled.
+        """
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(2, bundled_aggregations.get("count"))
+            self.assertTrue(bundled_aggregations.get("current_user_participated"))
+            # The latest thread event has some fields that don't matter.
+            self.assert_dict(
+                {
+                    "content": {
+                        "m.relates_to": {
+                            "event_id": self.parent_id,
+                            "rel_type": RelationTypes.THREAD,
+                        }
+                    },
+                    "event_id": thread_2,
+                    "sender": self.user_id,
+                    "type": "m.room.test",
+                },
+                bundled_aggregations.get("latest_event"),
+            )
+
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+    def test_aggregation_get_event_for_annotation(self) -> None:
+        """Test that annotations do not get bundled aggregations included
+        when directly requested.
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        annotation_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{annotation_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+    def test_aggregation_get_event_for_thread(self) -> None:
+        """Test that threads get bundled aggregations included when directly requested."""
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{thread_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(
+            channel.json_body["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
+        # It should also be included when the entire thread is requested.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+        thread_message = channel.json_body["chunk"][0]
+        self.assertEqual(
+            thread_message["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
+    def test_bundled_aggregations_with_filter(self) -> None:
+        """
+        If "unsigned" is an omitted field (due to filtering), adding the bundled
+        aggregations should not break.
+
+        Note that the spec allows for a server to return additional fields beyond
+        what is specified.
+        """
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+        # Note that the sync filter does not include "unsigned" as a field.
+        filter = urllib.parse.quote_plus(
+            b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # Ensure the timeline is limited, find the parent event.
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+        # Ensure there's bundled aggregations on it.
+        self.assertIn("unsigned", parent_event)
+        self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+    """Relations sent from an ignored user should be ignored."""
+
+    def _test_ignored_user(
+        self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+    ) -> None:
+        """
+        Fetch the relations and ensure they're all there, then ignore user2, and
+        repeat.
+        """
+        # Get the relations.
+        event_ids = self._get_related_events()
+        self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+        # Ignore user2 and re-do the requests.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {self.user2_id: {}}},
+            )
+        )
+
+        # Get the relations.
+        event_ids = self._get_related_events()
+        self.assertCountEqual(event_ids, allowed_event_ids)
+
+    def test_annotation(self) -> None:
+        """Annotations should ignore"""
+        # Send 2 from us, 2 from the to be ignored user.
+        allowed_event_ids = []
+        ignored_event_ids = []
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        allowed_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+        allowed_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            key="a",
+            access_token=self.user2_token,
+        )
+        ignored_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            key="c",
+            access_token=self.user2_token,
+        )
+        ignored_event_ids.append(channel.json_body["event_id"])
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+    def test_reference(self) -> None:
+        """Annotations should ignore"""
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        allowed_event_ids = [channel.json_body["event_id"]]
+
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+        )
+        ignored_event_ids = [channel.json_body["event_id"]]
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+    def test_thread(self) -> None:
+        """Annotations should ignore"""
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        allowed_event_ids = [channel.json_body["event_id"]]
+
+        channel = self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+        ignored_event_ids = [channel.json_body["event_id"]]
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
 class RelationRedactionTestCase(BaseRelationsTestCase):
-    """Test the behaviour of relations when the parent or child event is redacted."""
+    """
+    Test the behaviour of relations when the parent or child event is redacted.
+
+    The behaviour of each relation type is subtly different which causes the tests
+    to be a bit repetitive, they follow a naming scheme of:
+
+        test_redact_(relation|parent)_{relation_type}
+
+    The first bit of "relation" means that the event with the relation defined
+    on it (the child event) is to be redacted. A "parent" means that the target
+    of the relation (the parent event) is to be redacted.
+
+    The relation_type describes which type of relation is under test (i.e. it is
+    related to the value of rel_type in the event content).
+    """
 
     def _redact(self, event_id: str) -> None:
         channel = self.make_request(
@@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
 
     def test_redact_relation_annotation(self) -> None:
-        """Test that annotations of an event are properly handled after the
+        """
+        Test that annotations of an event are properly handled after the
         annotation is redacted.
+
+        The redacted relation should not be included in bundled aggregations or
+        the response to relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
         to_redact_event_id = channel.json_body["event_id"]
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
+        unredacted_event_id = channel.json_body["event_id"]
+
+        # Both relations should exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+        )
+
+        # Both relations appear in the aggregation.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
 
         # Redact one of the reactions.
         self._redact(to_redact_event_id)
 
-        # Ensure that the aggregations are correct.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
+        # The unredacted aggregation should still exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+    def test_redact_relation_thread(self) -> None:
+        """
+        Test that thread replies are properly handled after the thread reply redacted.
+
+        The redacted event should not be included in bundled aggregations or
+        the response to relations.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
+        )
+        unredacted_event_id = channel.json_body["event_id"]
+
+        # Note that the *last* event in the thread is redacted, as that gets
+        # included in the bundled aggregation.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 2", "msgtype": "m.text"},
+        )
+        to_redact_event_id = channel.json_body["event_id"]
+
+        # Both relations exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 2,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event returned is the event that will be redacted.
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            to_redact_event_id,
+        )
+
+        # Redact one of the reactions.
+        self._redact(to_redact_event_id)
+
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event is now the unredacted event.
         self.assertEqual(
-            channel.json_body,
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            unredacted_event_id,
         )
 
-    def test_redact_relation_edit(self) -> None:
+    def test_redact_parent_edit(self) -> None:
         """Test that edits of an event are redacted when the original event
         is redacted.
         """
         # Add a relation
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             parent_id=self.parent_id,
@@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.REPLACE, relations)
 
         # Redact the original event
         self._redact(self.parent_id)
 
-        # Try to check for remaining m.replace relations
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
+        # The relations are not returned.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 0)
+        self.assertEqual(relations, {})
 
-        # Check that no relations are returned
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
-
-    def test_redact_parent(self) -> None:
-        """Test that annotations of an event are redacted when the original event
+    def test_redact_parent_annotation(self) -> None:
+        """Test that annotations of an event are viewable when the original event
         is redacted.
         """
         # Add a relation
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
+        related_event_id = channel.json_body["event_id"]
+
+        # The relations should exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.ANNOTATION, relations)
+
+        # The aggregation should exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
 
         # Redact the original event.
         self._redact(self.parent_id)
 
-        # Check that aggregations returns zero
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
-            access_token=self.user_token,
+        # The relations are returned.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [related_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # There's nothing to aggregate.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_redact_parent_thread(self) -> None:
+        """
+        Test that thread replies are still available when the root event is redacted.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
+        )
+        related_event_id = channel.json_body["event_id"]
+
+        # Redact one of the reactions.
+        self._redact(self.parent_id)
+
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(len(event_ids), 1)
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            related_event_id,
+        )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f3bf8d0934..7b8fe6d025 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -24,6 +24,7 @@ from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
 from tests import unittest
+from tests.unittest import override_config
 
 one_hour_ms = 3600000
 one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
-        config["retention"] = {
+
+        # merge this default retention config with anything that was specified in
+        # @override_config
+        retention_config = {
             "enabled": True,
             "default_policy": {
                 "min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
             "allowed_lifetime_min": one_day_ms,
             "allowed_lifetime_max": one_day_ms * 3,
         }
+        retention_config.update(config.get("retention", {}))
+        config["retention"] = retention_config
 
         self.hs = self.setup_test_homeserver(config=config)
 
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 2)
 
+    @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
     def test_visibility(self) -> None:
         """Tests that synapse.visibility.filter_events_for_client correctly filters out
-        outdated events
+        outdated events, even if the purge job hasn't got to them yet.
+
+        We do this by setting a very long time between purge jobs.
         """
         store = self.hs.get_datastores().main
         storage = self.hs.get_storage()
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
-        events = []
 
         # Send a first event, which should be filtered out at the end of the test.
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
-
-        # Get the event from the store so that we end up with a FrozenEvent that we can
-        # give to filter_events_for_client. We need to do this now because the event won't
-        # be in the database anymore after it has expired.
-        events.append(self.get_success(store.get_event(resp.get("event_id"))))
+        first_event_id = resp.get("event_id")
 
         # Advance the time by 2 days. We're using the default retention policy, therefore
         # after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         # Send another event, which shouldn't get filtered out.
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
-
         valid_event_id = resp.get("event_id")
 
-        events.append(self.get_success(store.get_event(valid_event_id)))
-
         # Advance the time by another 2 days. After this, the first event should be
         # outdated but not the second one.
         self.reactor.advance(one_day_ms * 2 / 1000)
 
-        # Run filter_events_for_client with our list of FrozenEvents.
+        # Fetch the events, and run filter_events_for_client on them
+        events = self.get_success(
+            store.get_events_as_list([first_event_id, valid_event_id])
+        )
+        self.assertEqual(2, len(events), "events retrieved from database")
         filtered_events = self.get_success(
             filter_events_for_client(storage, self.user_id, events)
         )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 37866ee330..3a9617d6da 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def test_filter_relation_senders(self) -> None:
         # Messages which second user reacted to.
-        filter = {"io.element.relation_senders": [self.second_user_id]}
+        filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_1)
 
         # Messages which third user reacted to.
-        filter = {"io.element.relation_senders": [self.third_user_id]}
+        filter = {"related_by_senders": [self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_2)
 
         # Messages which either user reacted to.
-        filter = {
-            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
-        }
+        filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 2, chunk)
         self.assertCountEqual(
@@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def test_filter_relation_type(self) -> None:
         # Messages which have annotations.
-        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_1)
 
         # Messages which have references.
-        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_2)
 
         # Messages which have either annotations or references.
         filter = {
-            "io.element.relation_types": [
+            "related_by_rel_types": [
                 RelationTypes.ANNOTATION,
                 RelationTypes.REFERENCE,
             ]
@@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     def test_filter_relation_senders_and_type(self) -> None:
         # Messages which second user reacted to.
         filter = {
-            "io.element.relation_senders": [self.second_user_id],
-            "io.element.relation_types": [RelationTypes.ANNOTATION],
+            "related_by_senders": [self.second_user_id],
+            "related_by_rel_types": [RelationTypes.ANNOTATION],
         }
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 58f1ea11b7..e7de67e3a3 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(args[0], user_id)
         self.assertFalse(args[1])
         self.assertTrue(args[2])
+
+    def test_check_can_deactivate_user(self) -> None:
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_deactivate_user_callbacks.append(
+            deactivation_mock,
+        )
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+        tok = self.login("altan", "password")
+
+        # Deactivate that user.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/account/deactivate",
+            {
+                "auth": {
+                    "type": LoginType.PASSWORD,
+                    "password": "password",
+                    "identifier": {
+                        "type": "m.id.user",
+                        "user": user_id,
+                    },
+                },
+                "erase": True,
+            },
+            access_token=tok,
+        )
+
+        # Check that the deactivation was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], user_id)
+
+        # Check that the request was not made by an admin
+        self.assertEqual(args[1], False)
+
+    def test_check_can_deactivate_user_admin(self) -> None:
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation triggered by a server admin.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_deactivate_user_callbacks.append(
+            deactivation_mock,
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": True},
+            access_token=admin_tok,
+        )
+
+        # Check that the deactivation was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], user_id)
+
+        # Check that the mock was made by an admin
+        self.assertEqual(args[1], True)
+
+    def test_check_can_shutdown_room(self) -> None:
+        """Tests that the check_can_shutdown_room module callback is called
+        correctly when processing an admin's shutdown room request.
+        """
+        # Register a mocked callback.
+        shutdown_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_shutdown_room_callbacks.append(
+            shutdown_mock,
+        )
+
+        # Register an admin user.
+        admin_user_id = self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Shutdown the room.
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v2/rooms/%s" % self.room_id,
+            {},
+            access_token=admin_tok,
+        )
+
+        # Check that the shutdown was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        shutdown_mock.assert_called_once()
+        args = shutdown_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], admin_user_id)
+
+        # Check that the mock was called with the right room ID
+        self.assertEqual(args[1], self.room_id)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3b5747cb12..8d8251b2ac 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,3 +1,18 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
 from unittest.mock import Mock, call
 
 from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
 
 
 class HttpTransactionCacheTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = MockClock()
         self.hs = Mock()
         self.hs.get_clock = Mock(return_value=self.clock)
         self.hs.get_auth = Mock()
         self.cache = HttpTransactionCache(self.hs)
 
-        self.mock_http_response = (200, "GOOD JOB!")
+        self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
         self.mock_key = "foo"
 
     @defer.inlineCallbacks
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 4672a68596..978c252f84 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -13,19 +13,24 @@
 # limitations under the License.
 import urllib.parse
 from io import BytesIO, StringIO
+from typing import Any, Dict, Optional, Union
 from unittest.mock import Mock
 
 import signedjson.key
 from canonicaljson import encode_canonical_json
-from nacl.signing import SigningKey
 from signedjson.sign import sign_json
+from signedjson.types import SigningKey
 
-from twisted.web.resource import NoResource
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import NoResource, Resource
 
 from synapse.crypto.keyring import PerspectivesKeyFetcher
 from synapse.http.site import SynapseRequest
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
 from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.stringutils import random_string
 
@@ -35,11 +40,11 @@ from tests.utils import default_config
 
 
 class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock()
         return self.setup_test_homeserver(federation_http_client=self.http_client)
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> Resource:
         return create_resource_tree(
             {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
         )
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        async def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(
+            destination: str,
+            path: str,
+            ignore_backoff: bool = False,
+            **kwargs: Any,
+        ) -> Union[JsonDict, list]:
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         Checks that the response is a 200 and returns the decoded json body.
         """
         channel = FakeChannel(self.site, self.reactor)
-        req = SynapseRequest(channel, self.site)
+        # channel is a `FakeChannel` but `HTTPChannel` is expected
+        req = SynapseRequest(channel, self.site)  # type: ignore[arg-type]
         req.content = BytesIO(b"")
         req.requestReceived(
             b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         resp = channel.json_body
         return resp
 
-    def test_get_key(self):
+    def test_get_key(self) -> None:
         """Fetch a remote key"""
         SERVER_NAME = "remote.server"
         testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         self.assertIn(SERVER_NAME, keys[0]["signatures"])
         self.assertIn(self.hs.hostname, keys[0]["signatures"])
 
-    def test_get_own_key(self):
+    def test_get_own_key(self) -> None:
         """Fetch our own key"""
         testkey = signedjson.key.generate_signing_key("ver1")
         self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
     endpoint, to check that the two implementations are compatible.
     """
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # make a second homeserver, configured to use the first one as a key notary
         self.http_client2 = Mock()
         config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        async def post_json(destination, path, data):
+        async def post_json(
+            destination: str, path: str, data: Optional[JsonDict] = None
+        ) -> Union[JsonDict, list]:
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             )
 
             channel = FakeChannel(self.site, self.reactor)
-            req = SynapseRequest(channel, self.site)
+            # channel is a `FakeChannel` but `HTTPChannel` is expected
+            req = SynapseRequest(channel, self.site)  # type: ignore[arg-type]
             req.content = BytesIO(encode_canonical_json(data))
 
             req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         self.http_client2.post_json.side_effect = post_json
 
-    def test_get_key(self):
+    def test_get_key(self) -> None:
         """Fetch a key belonging to a random server"""
         # make up a key to be fetched.
         testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             signedjson.key.encode_verify_key_base64(testkey.verify_key),
         )
 
-    def test_get_notary_key(self):
+    def test_get_notary_key(self) -> None:
         """Fetch a key belonging to the notary server"""
         # make up a key to be fetched. We randomise the keyid to try to get it to
         # appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             signedjson.key.encode_verify_key_base64(testkey.verify_key),
         )
 
-    def test_get_notary_keyserver_key(self):
+    def test_get_notary_keyserver_key(self) -> None:
         """Fetch the notary's keyserver key"""
         # we expect hs1 to make a regular key request to itself
         self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index f761e23f1b..c73179151a 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
         b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
     }
 
-    def tests(self):
+    def tests(self) -> None:
         for hdr, expected in self.TEST_CASES.items():
             res = get_filename_from_headers({b"Content-Disposition": [hdr]})
             self.assertEqual(
                 res,
                 expected,
-                "expected output for %s to be %s but was %s" % (hdr, expected, res),
+                f"expected output for {hdr!r} to be {expected} but was {res}",
             )
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..43e6f0f70a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -21,12 +21,12 @@ from tests import unittest
 
 
 class MediaFilePathsTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
 
         self.filepaths = MediaFilePaths("/media_store")
 
-    def test_local_media_filepath(self):
+    def test_local_media_filepath(self) -> None:
         """Test local media paths"""
         self.assertEqual(
             self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_local_media_thumbnail(self):
+    def test_local_media_thumbnail(self) -> None:
         """Test local media thumbnail paths"""
         self.assertEqual(
             self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_local_media_thumbnail_dir(self):
+    def test_local_media_thumbnail_dir(self) -> None:
         """Test local media thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
             "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_remote_media_filepath(self):
+    def test_remote_media_filepath(self) -> None:
         """Test remote media paths"""
         self.assertEqual(
             self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_remote_media_thumbnail(self):
+    def test_remote_media_thumbnail(self) -> None:
         """Test remote media thumbnail paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_remote_media_thumbnail_legacy(self):
+    def test_remote_media_thumbnail_legacy(self) -> None:
         """Test old-style remote media thumbnail paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
         )
 
-    def test_remote_media_thumbnail_dir(self):
+    def test_remote_media_thumbnail_dir(self) -> None:
         """Test remote media thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_filepath(self):
+    def test_url_cache_filepath(self) -> None:
         """Test URL cache paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
         )
 
-    def test_url_cache_filepath_legacy(self):
+    def test_url_cache_filepath_legacy(self) -> None:
         """Test old-style URL cache paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_filepath_dirs_to_delete(self):
+    def test_url_cache_filepath_dirs_to_delete(self) -> None:
         """Test URL cache cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ["/media_store/url_cache/2020-01-02"],
         )
 
-    def test_url_cache_filepath_dirs_to_delete_legacy(self):
+    def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
         """Test old-style URL cache cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_url_cache_thumbnail(self):
+    def test_url_cache_thumbnail(self) -> None:
         """Test URL cache thumbnail paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
         )
 
-    def test_url_cache_thumbnail_legacy(self):
+    def test_url_cache_thumbnail_legacy(self) -> None:
         """Test old-style URL cache thumbnail paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_url_cache_thumbnail_directory(self):
+    def test_url_cache_thumbnail_directory(self) -> None:
         """Test URL cache thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
         )
 
-    def test_url_cache_thumbnail_directory_legacy(self):
+    def test_url_cache_thumbnail_directory_legacy(self) -> None:
         """Test old-style URL cache thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_thumbnail_dirs_to_delete(self):
+    def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
         """Test URL cache thumbnail cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+    def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
         """Test old-style URL cache thumbnail cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_server_name_validation(self):
+    def test_server_name_validation(self) -> None:
         """Test validation of server names"""
         self._test_path_validation(
             [
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_file_id_validation(self):
+    def test_file_id_validation(self) -> None:
         """Test validation of local, remote and legacy URL cache file / media IDs"""
         # File / media IDs get split into three parts to form paths, consisting of the
         # first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             invalid_values=invalid_file_ids,
         )
 
-    def test_url_cache_media_id_validation(self):
+    def test_url_cache_media_id_validation(self) -> None:
         """Test validation of URL cache media IDs"""
         self._test_path_validation(
             [
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_content_type_validation(self):
+    def test_content_type_validation(self) -> None:
         """Test validation of thumbnail content types"""
         self._test_path_validation(
             [
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_thumbnail_method_validation(self):
+    def test_thumbnail_method_validation(self) -> None:
         """Test validation of thumbnail methods"""
         self._test_path_validation(
             [
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
         parameter: str,
         valid_values: Iterable[str],
         invalid_values: Iterable[str],
-    ):
+    ) -> None:
         """Test that the specified methods validate the named parameter as expected
 
         Args:
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index a4b57e3d1f..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
     _get_html_media_encodings,
     decode_body,
     parse_html_to_open_graph,
-    rebase_url,
     summarize_paragraphs,
 )
 
@@ -32,7 +31,7 @@ class SummarizeTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
-    def test_long_summarize(self):
+    def test_long_summarize(self) -> None:
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
             Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +89,7 @@ class SummarizeTestCase(unittest.TestCase):
             " Tromsøya had a population of 36,088. Substantial parts of the urban…",
         )
 
-    def test_short_summarize(self):
+    def test_short_summarize(self) -> None:
         example_paras = [
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +116,7 @@ class SummarizeTestCase(unittest.TestCase):
             " most of the year.",
         )
 
-    def test_small_then_large_summarize(self):
+    def test_small_then_large_summarize(self) -> None:
         example_paras = [
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +149,7 @@ class CalcOgTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -161,11 +160,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_comment(self):
+    def test_comment(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -177,11 +176,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_comment2(self):
+    def test_comment2(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(
             og,
@@ -206,7 +205,7 @@ class CalcOgTestCase(unittest.TestCase):
             },
         )
 
-    def test_script(self):
+    def test_script(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -218,11 +217,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_missing_title(self):
+    def test_missing_title(self) -> None:
         html = b"""
         <html>
         <body>
@@ -232,11 +231,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
-    def test_h1_as_title(self):
+    def test_h1_as_title(self) -> None:
         html = b"""
         <html>
         <meta property="og:description" content="Some text."/>
@@ -247,11 +246,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
-    def test_missing_title_and_broken_h1(self):
+    def test_missing_title_and_broken_h1(self) -> None:
         html = b"""
         <html>
         <body>
@@ -262,23 +261,23 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
-    def test_empty(self):
+    def test_empty(self) -> None:
         """Test a body with no data in it."""
         html = b""
         tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
-    def test_no_tree(self):
+    def test_no_tree(self) -> None:
         """A valid body with no tree in it."""
         html = b"\x00"
         tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
-    def test_xml(self):
+    def test_xml(self) -> None:
         """Test decoding XML and ensure it works properly."""
         # Note that the strip() call is important to ensure the xml tag starts
         # at the initial byte.
@@ -290,10 +289,10 @@ class CalcOgTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_invalid_encoding(self):
+    def test_invalid_encoding(self) -> None:
         """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
         html = b"""
         <html>
@@ -304,10 +303,10 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_invalid_encoding2(self):
+    def test_invalid_encoding2(self) -> None:
         """A body which doesn't match the sent character encoding."""
         # Note that this contains an invalid UTF-8 sequence in the title.
         html = b"""
@@ -319,10 +318,10 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
-    def test_windows_1252(self):
+    def test_windows_1252(self) -> None:
         """A body which uses cp1252, but doesn't declare that."""
         html = b"""
         <html>
@@ -333,12 +332,12 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
 
 class MediaEncodingTestCase(unittest.TestCase):
-    def test_meta_charset(self):
+    def test_meta_charset(self) -> None:
         """A character encoding is found via the meta tag."""
         encodings = _get_html_media_encodings(
             b"""
@@ -363,7 +362,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_meta_charset_underscores(self):
+    def test_meta_charset_underscores(self) -> None:
         """A character encoding contains underscore."""
         encodings = _get_html_media_encodings(
             b"""
@@ -376,7 +375,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
 
-    def test_xml_encoding(self):
+    def test_xml_encoding(self) -> None:
         """A character encoding is found via the meta tag."""
         encodings = _get_html_media_encodings(
             b"""
@@ -388,7 +387,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_meta_xml_encoding(self):
+    def test_meta_xml_encoding(self) -> None:
         """Meta tags take precedence over XML encoding."""
         encodings = _get_html_media_encodings(
             b"""
@@ -402,7 +401,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
 
-    def test_content_type(self):
+    def test_content_type(self) -> None:
         """A character encoding is found via the Content-Type header."""
         # Test a few variations of the header.
         headers = (
@@ -417,12 +416,12 @@ class MediaEncodingTestCase(unittest.TestCase):
             encodings = _get_html_media_encodings(b"", header)
             self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_fallback(self):
+    def test_fallback(self) -> None:
         """A character encoding cannot be found in the body or header."""
         encodings = _get_html_media_encodings(b"", "text/html")
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
-    def test_duplicates(self):
+    def test_duplicates(self) -> None:
         """Ensure each encoding is only attempted once."""
         encodings = _get_html_media_encodings(
             b"""
@@ -436,7 +435,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
-    def test_unknown_invalid(self):
+    def test_unknown_invalid(self) -> None:
         """A character encoding should be ignored if it is unknown or invalid."""
         encodings = _get_html_media_encodings(
             b"""
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset="invalid"',
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
-    def test_relative(self):
-        """Relative URLs should be resolved based on the context of the base URL."""
-        self.assertEqual(
-            rebase_url("subpage", "https://example.com/foo/"),
-            "https://example.com/foo/subpage",
-        )
-        self.assertEqual(
-            rebase_url("sibling", "https://example.com/foo"),
-            "https://example.com/sibling",
-        )
-        self.assertEqual(
-            rebase_url("/bar", "https://example.com/foo/"),
-            "https://example.com/bar",
-        )
-
-    def test_absolute(self):
-        """Absolute URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("https://alice.com/a/", "https://example.com/foo/"),
-            "https://alice.com/a/",
-        )
-
-    def test_data(self):
-        """Data URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
-            "data:,Hello%2C%20World%21",
-        )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index cba9be17c4..7204b2dfe0 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
-from typing import Optional
+from typing import Any, BinaryIO, Dict, List, Optional, Union
 from unittest.mock import Mock
 from urllib import parse
 
@@ -26,18 +26,24 @@ from PIL import Image as Image
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import SMALL_PNG
 from tests.utils import default_config
 
@@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
 
     needs_threadpool = True
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
         self.addCleanup(shutil.rmtree, self.test_dir)
 
@@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
             hs, self.primary_base_path, self.filepaths, storage_providers
         )
 
-    def test_ensure_media_is_in_local_cache(self):
+    def test_ensure_media_is_in_local_cache(self) -> None:
         media_id = "some_media_id"
         test_body = "Test\n"
 
@@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
         self.assertEqual(test_body, body)
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
 class _TestImage:
     """An image for testing thumbnailing with the expected results
 
@@ -121,18 +127,18 @@ class _TestImage:
             a 404 is expected.
     """
 
-    data = attr.ib(type=bytes)
-    content_type = attr.ib(type=bytes)
-    extension = attr.ib(type=bytes)
-    expected_cropped = attr.ib(type=Optional[bytes], default=None)
-    expected_scaled = attr.ib(type=Optional[bytes], default=None)
-    expected_found = attr.ib(default=True, type=bool)
+    data: bytes
+    content_type: bytes
+    extension: bytes
+    expected_cropped: Optional[bytes] = None
+    expected_scaled: Optional[bytes] = None
+    expected_found: bool = True
 
 
 @parameterized_class(
     ("test_image",),
     [
-        # smoll png
+        # small png
         (
             _TestImage(
                 SMALL_PNG,
@@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
     hijack_auth = True
     user_id = "@test:user"
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        def get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            args: Optional[Dict[str, Union[str, List[str]]]] = None,
+            max_size: Optional[int] = None,
+        ) -> Deferred:
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         media_resource = hs.get_media_repository_resource()
         self.download_resource = media_resource.children[b"download"]
@@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         self.media_id = "example.com/12345"
 
-    def _req(self, content_disposition, include_content_type=True):
-
+    def _req(
+        self, content_disposition: Optional[bytes], include_content_type: bool = True
+    ) -> FakeChannel:
         channel = make_request(
             self.reactor,
             FakeSite(self.download_resource, self.reactor),
@@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         return channel
 
-    def test_handle_missing_content_type(self):
+    def test_handle_missing_content_type(self) -> None:
         channel = self._req(
             b"inline; filename=out" + self.test_image.extension,
             include_content_type=False,
@@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
         )
 
-    def test_disposition_filename_ascii(self):
+    def test_disposition_filename_ascii(self) -> None:
         """
         If the filename is filename=<ascii> then Synapse will decode it as an
         ASCII string, and use filename= in the response.
@@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             [b"inline; filename=out" + self.test_image.extension],
         )
 
-    def test_disposition_filenamestar_utf8escaped(self):
+    def test_disposition_filenamestar_utf8escaped(self) -> None:
         """
         If the filename is filename=*utf8''<utf8 escaped> then Synapse will
         correctly decode it as the UTF-8 string, and use filename* in the
@@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
         )
 
-    def test_disposition_none(self):
+    def test_disposition_none(self) -> None:
         """
         If there is no filename, one isn't passed on in the Content-Disposition
         of the request.
@@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
-    def test_thumbnail_crop(self):
+    def test_thumbnail_crop(self) -> None:
         """Test that a cropped remote thumbnail is available."""
         self._test_thumbnail(
             "crop", self.test_image.expected_cropped, self.test_image.expected_found
         )
 
-    def test_thumbnail_scale(self):
+    def test_thumbnail_scale(self) -> None:
         """Test that a scaled remote thumbnail is available."""
         self._test_thumbnail(
             "scale", self.test_image.expected_scaled, self.test_image.expected_found
         )
 
-    def test_invalid_type(self):
+    def test_invalid_type(self) -> None:
         """An invalid thumbnail type is never available."""
         self._test_thumbnail("invalid", None, False)
 
     @unittest.override_config(
         {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
     )
-    def test_no_thumbnail_crop(self):
+    def test_no_thumbnail_crop(self) -> None:
         """
         Override the config to generate only scaled thumbnails, but request a cropped one.
         """
@@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
     )
-    def test_no_thumbnail_scale(self):
+    def test_no_thumbnail_scale(self) -> None:
         """
         Override the config to generate only cropped thumbnails, but request a scaled one.
         """
         self._test_thumbnail("scale", None, False)
 
-    def test_thumbnail_repeated_thumbnail(self):
+    def test_thumbnail_repeated_thumbnail(self) -> None:
         """Test that fetching the same thumbnail works, and deleting the on disk
         thumbnail regenerates it.
         """
@@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.result["body"],
             )
 
-    def _test_thumbnail(self, method, expected_body, expected_found):
+    def _test_thumbnail(
+        self, method: str, expected_body: Optional[bytes], expected_found: bool
+    ) -> None:
         params = "?width=32&height=32&method=" + method
         channel = make_request(
             self.reactor,
@@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             )
 
     @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
-    def test_same_quality(self, method, desired_size):
+    def test_same_quality(self, method: str, desired_size: int) -> None:
         """Test that choosing between thumbnails with the same quality rating succeeds.
 
         We are not particular about which thumbnail is chosen."""
@@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             )
         )
 
-    def test_x_robots_tag_header(self):
+    def test_x_robots_tag_header(self) -> None:
         """
         Tests that the `X-Robots-Tag` header is present, which informs web crawlers
         to not index, archive, or follow links in media.
@@ -540,29 +555,38 @@ class TestSpamChecker:
     `evil`.
     """
 
-    def __init__(self, config, api):
+    def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
         self.config = config
         self.api = api
 
-    def parse_config(config):
+    def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
-    async def check_event_for_spam(self, foo):
+    async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
         return False  # allow all events
 
-    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+    async def user_may_invite(
+        self,
+        inviter_userid: str,
+        invitee_userid: str,
+        room_id: str,
+    ) -> bool:
         return True  # allow all invites
 
-    async def user_may_create_room(self, userid):
+    async def user_may_create_room(self, userid: str) -> bool:
         return True  # allow all room creations
 
-    async def user_may_create_room_alias(self, userid, room_alias):
+    async def user_may_create_room_alias(
+        self, userid: str, room_alias: RoomAlias
+    ) -> bool:
         return True  # allow all room aliases
 
-    async def user_may_publish_room(self, userid, room_id):
+    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         return True  # allow publishing of all rooms
 
-    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> bool:
         buf = BytesIO()
         await file_wrapper.write_chunks_to(buf.write)
 
@@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
         admin.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
@@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
         load_legacy_spam_checkers(hs)
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
         config.update(
@@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def test_upload_innocent(self):
+    def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
         self.helper.upload_media(
             self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
         )
 
-    def test_upload_ban(self):
+    def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
         get rejected by the spam checker.
         """
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 048d0ca44a..f38d7225f8 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -16,7 +16,7 @@ import json
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
 
 
 class OEmbedTests(HomeserverTestCase):
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
-        self.oembed = OEmbedProvider(homeserver)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.oembed = OEmbedProvider(hs)
 
-    def parse_response(self, response: JsonDict):
+    def parse_response(self, response: JsonDict) -> OEmbedResult:
         return self.oembed.parse_oembed_response(
             "https://test", json.dumps(response).encode("utf-8")
         )
 
-    def test_version(self):
+    def test_version(self) -> None:
         """Accept versions that are similar to 1.0 as a string or int (or missing)."""
         for version in ("1.0", 1.0, 1):
             result = self.parse_response({"version": version, "type": "link"})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index da2c533260..5148c39874 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -16,16 +16,21 @@ import base64
 import json
 import os
 import re
+from typing import Any, Dict, Optional, Sequence, Tuple, Type
 from urllib.parse import urlencode
 
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
 
 from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from tests import unittest
@@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         b"</head></html>"
     )
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["url_preview_enabled"] = True
@@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         self.media_repo = hs.get_media_repository_resource()
         self.preview_url = self.media_repo.children[b"preview_url"]
 
-        self.lookups = {}
+        self.lookups: Dict[str, Any] = {}
 
         class Resolver:
             def resolveHostName(
                 _self,
-                resolutionReceiver,
-                hostName,
-                portNumber=0,
-                addressTypes=None,
-                transportSemantics="TCP",
-            ):
+                resolutionReceiver: IResolutionReceiver,
+                hostName: str,
+                portNumber: int = 0,
+                addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+                transportSemantics: str = "TCP",
+            ) -> IResolutionReceiver:
 
                 resolution = HostResolution(hostName)
                 resolutionReceiver.resolutionBegan(resolution)
@@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 resolutionReceiver.resolutionComplete()
                 return resolutionReceiver
 
-        self.reactor.nameResolver = Resolver()
+        self.reactor.nameResolver = Resolver()  # type: ignore[assignment]
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> MediaRepositoryResource:
         return self.hs.get_media_repository_resource()
 
     def _assert_small_png(self, json_body: JsonDict) -> None:
@@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(json_body["og:image:type"], "image/png")
         self.assertEqual(json_body["matrix:image:size"], 67)
 
-    def test_cache_returns_correct_type(self):
+    def test_cache_returns_correct_type(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         channel = self.make_request(
@@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_non_ascii_preview_httpequiv(self):
+    def test_non_ascii_preview_httpequiv(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
-    def test_video_rejected(self):
+    def test_video_rejected(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = b"anything"
@@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_audio_rejected(self):
+    def test_audio_rejected(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = b"anything"
@@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_non_ascii_preview_content_type(self):
+    def test_non_ascii_preview_content_type(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
-    def test_overlong_title(self):
+    def test_overlong_title(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # We should only see the `og:description` field, as `title` is too long and should be stripped out
         self.assertCountEqual(["og:description"], res.keys())
 
-    def test_ipaddr(self):
+    def test_ipaddr(self) -> None:
         """
         IP addresses can be previewed directly.
         """
@@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_blacklisted_ip_specific(self):
+    def test_blacklisted_ip_specific(self) -> None:
         """
         Blacklisted IP addresses, found via DNS, are not spidered.
         """
@@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_range(self):
+    def test_blacklisted_ip_range(self) -> None:
         """
         Blacklisted IP ranges, IPs found over DNS, are not spidered.
         """
@@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_specific_direct(self):
+    def test_blacklisted_ip_specific_direct(self) -> None:
         """
         Blacklisted IP addresses, accessed directly, are not spidered.
         """
@@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 403)
 
-    def test_blacklisted_ip_range_direct(self):
+    def test_blacklisted_ip_range_direct(self) -> None:
         """
         Blacklisted IP ranges, accessed directly, are not spidered.
         """
@@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_range_whitelisted_ip(self):
+    def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
         """
         Blacklisted but then subsequently whitelisted IP addresses can be
         spidered.
@@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_blacklisted_ip_with_external_ip(self):
+    def test_blacklisted_ip_with_external_ip(self) -> None:
         """
         If a hostname resolves a blacklisted IP, even if there's a
         non-blacklisted one, it will be rejected.
@@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ipv6_specific(self):
+    def test_blacklisted_ipv6_specific(self) -> None:
         """
         Blacklisted IP addresses, found via DNS, are not spidered.
         """
@@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ipv6_range(self):
+    def test_blacklisted_ipv6_range(self) -> None:
         """
         Blacklisted IP ranges, IPs found over DNS, are not spidered.
         """
@@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_OPTIONS(self):
+    def test_OPTIONS(self) -> None:
         """
         OPTIONS returns the OPTIONS.
         """
@@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {})
 
-    def test_accept_language_config_option(self):
+    def test_accept_language_config_option(self) -> None:
         """
         Accept-Language header is sent to the remote server
         """
@@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             server.data,
         )
 
-    def test_data_url(self):
+    def test_data_url(self) -> None:
         """
         Requesting to preview a data URL is not supported.
         """
@@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 500)
 
-    def test_inline_data_url(self):
+    def test_inline_data_url(self) -> None:
         """
         An inline image (as a data URL) should be parsed properly.
         """
@@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self._assert_small_png(channel.json_body)
 
-    def test_oembed_photo(self):
+    def test_oembed_photo(self) -> None:
         """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
         self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
         self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
         self._assert_small_png(body)
 
-    def test_oembed_rich(self):
+    def test_oembed_rich(self) -> None:
         """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
         self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
 
@@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_oembed_format(self):
+    def test_oembed_format(self) -> None:
         """Test an oEmbed endpoint which requires the format in the URL."""
         self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
 
@@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_oembed_autodiscovery(self):
+    def test_oembed_autodiscovery(self) -> None:
         """
         Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
         1. Request a preview of a URL which is not known to the oEmbed code.
@@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
         self._assert_small_png(body)
 
-    def _download_image(self):
+    def _download_image(self) -> Tuple[str, str]:
         """Downloads an image into the URL cache.
         Returns:
             A (host, media_id) tuple representing the MXC URI of the image.
@@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertIsNone(_port)
         return host, media_id
 
-    def test_storage_providers_exclude_files(self):
+    def test_storage_providers_exclude_files(self) -> None:
         """Test that files are not stored in or fetched from storage providers."""
         host, media_id = self._download_image()
 
@@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "URL cache file was unexpectedly retrieved from a storage provider",
         )
 
-    def test_storage_providers_exclude_thumbnails(self):
+    def test_storage_providers_exclude_thumbnails(self) -> None:
         """Test that thumbnails are not stored in or fetched from storage providers."""
         host, media_id = self._download_image()
 
@@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "URL cache thumbnail was unexpectedly retrieved from a storage provider",
         )
 
-    def test_cache_expiry(self):
+    def test_cache_expiry(self) -> None:
         """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
         self.preview_url.clock = MockClock()
 
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 01d48c3860..da325955f8 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from http import HTTPStatus
 
 from synapse.rest.health import HealthResource
 
@@ -19,12 +19,12 @@ from tests import unittest
 
 
 class HealthCheckTests(unittest.HomeserverTestCase):
-    def create_test_resource(self):
+    def create_test_resource(self) -> HealthResource:
         # replace the JsonResource with a HealthResource.
         return HealthResource()
 
-    def test_health(self):
+    def test_health(self) -> None:
         channel = self.make_request("GET", "/health", shorthand=False)
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 118aa93a32..11f78f52b8 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+
 from twisted.web.resource import Resource
 
 from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
 
 
 class WellKnownTests(unittest.HomeserverTestCase):
-    def create_test_resource(self):
+    def create_test_resource(self) -> Resource:
         # replace the JsonResource with a Resource wrapping the WellKnownResource
         res = Resource()
         res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
             "default_identity_server": "https://testis",
         }
     )
-    def test_client_well_known(self):
+    def test_client_well_known(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(
             channel.json_body,
             {
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
             "public_baseurl": None,
         }
     )
-    def test_client_well_known_no_public_baseurl(self):
+    def test_client_well_known_no_public_baseurl(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
 
     @unittest.override_config({"serve_server_wellknown": True})
-    def test_server_well_known(self):
+    def test_server_well_known(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/server", shorthand=False
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(
             channel.json_body,
             {"m.server": "test:443"},
         )
 
-    def test_server_well_known_disabled(self):
+    def test_server_well_known_disabled(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/server", shorthand=False
         )
-        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)