summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_filtering.py107
-rw-r--r--tests/config/test_load.py9
-rw-r--r--tests/crypto/test_keyring.py56
-rw-r--r--tests/handlers/test_appservice.py8
-rw-r--r--tests/handlers/test_auth.py10
-rw-r--r--tests/handlers/test_directory.py95
-rw-r--r--tests/handlers/test_e2e_keys.py32
-rw-r--r--tests/handlers/test_password_providers.py5
-rw-r--r--tests/handlers/test_register.py9
-rw-r--r--tests/handlers/test_room_summary.py55
-rw-r--r--tests/handlers/test_sync.py5
-rw-r--r--tests/replication/_base.py5
-rw-r--r--tests/rest/admin/test_admin.py48
-rw-r--r--tests/rest/admin/test_background_updates.py154
-rw-r--r--tests/rest/admin/test_room.py982
-rw-r--r--tests/rest/admin/test_user.py30
-rw-r--r--tests/rest/client/test_capabilities.py8
-rw-r--r--tests/rest/client/test_directory.py105
-rw-r--r--tests/rest/client/test_login.py73
-rw-r--r--tests/rest/client/test_relations.py173
-rw-r--r--tests/rest/client/test_rooms.py154
-rw-r--r--tests/rest/client/utils.py71
-rw-r--r--tests/server.py18
-rw-r--r--tests/storage/test_profile.py9
-rw-r--r--tests/storage/test_rollback_worker.py2
-rw-r--r--tests/storage/test_roommember.py48
-rw-r--r--tests/storage/test_stream.py207
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/unittest.py70
29 files changed, 2262 insertions, 288 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index f44c91a373..b7fc33dc94 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -15,6 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from unittest.mock import patch
+
 import jsonschema
 
 from synapse.api.constants import EventContentFields
@@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             {"presence": {"senders": ["@bar;pik.test.com"]}},
         ]
         for filter in invalid_filters:
-            with self.assertRaises(SynapseError) as check_filter_error:
+            with self.assertRaises(SynapseError):
                 self.filtering.check_valid_filter(filter)
-                self.assertIsInstance(check_filter_error.exception, SynapseError)
 
     def test_valid_filters(self):
         valid_filters = [
@@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
         event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
 
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_types_works_with_wildcards(self):
         definition = {"types": ["m.*", "org.matrix.foo.bar"]}
         event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_types_works_with_unknowns(self):
         definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
@@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="now.for.something.completely.different",
             room_id="!foo:bar",
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_types_works_with_literals(self):
         definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
         event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_types_works_with_wildcards(self):
         definition = {"not_types": ["m.room.message", "org.matrix.*"]}
         event = MockEvent(
             sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_types_works_with_unknowns(self):
         definition = {"not_types": ["m.*", "org.*"]}
         event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_types_takes_priority_over_types(self):
         definition = {
@@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             "types": ["m.room.message", "m.room.topic"],
         }
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_senders_works_with_literals(self):
         definition = {"senders": ["@flibble:wibble"]}
         event = MockEvent(
             sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
         )
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_senders_works_with_unknowns(self):
         definition = {"senders": ["@flibble:wibble"]}
         event = MockEvent(
             sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_senders_works_with_literals(self):
         definition = {"not_senders": ["@flibble:wibble"]}
         event = MockEvent(
             sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_senders_works_with_unknowns(self):
         definition = {"not_senders": ["@flibble:wibble"]}
         event = MockEvent(
             sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
         )
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_senders_takes_priority_over_senders(self):
         definition = {
@@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         event = MockEvent(
             sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_rooms_works_with_literals(self):
         definition = {"rooms": ["!secretbase:unknown"]}
         event = MockEvent(
             sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
         )
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_rooms_works_with_unknowns(self):
         definition = {"rooms": ["!secretbase:unknown"]}
@@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",
             room_id="!anothersecretbase:unknown",
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_rooms_works_with_literals(self):
         definition = {"not_rooms": ["!anothersecretbase:unknown"]}
@@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",
             room_id="!anothersecretbase:unknown",
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_rooms_works_with_unknowns(self):
         definition = {"not_rooms": ["!secretbase:unknown"]}
@@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",
             room_id="!anothersecretbase:unknown",
         )
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_not_rooms_takes_priority_over_rooms(self):
         definition = {
@@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         event = MockEvent(
             sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_combined_event(self):
         definition = {
@@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",  # yup
             room_id="!stage:unknown",  # yup
         )
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_definition_combined_event_bad_sender(self):
         definition = {
@@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",  # yup
             room_id="!stage:unknown",  # yup
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_combined_event_bad_room(self):
         definition = {
@@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="m.room.message",  # yup
             room_id="!piggyshouse:muppets",  # nope
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_definition_combined_event_bad_type(self):
         definition = {
@@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             type="muppets.misspiggy.kisses",  # nope
             room_id="!stage:unknown",  # yup
         )
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_filter_labels(self):
         definition = {"org.matrix.labels": ["#fun"]}
@@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             content={EventContentFields.LABELS: ["#fun"]},
         )
 
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
         event = MockEvent(
             sender="@foo:bar",
@@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             content={EventContentFields.LABELS: ["#notfun"]},
         )
 
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
     def test_filter_not_labels(self):
         definition = {"org.matrix.not_labels": ["#fun"]}
@@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             content={EventContentFields.LABELS: ["#fun"]},
         )
 
-        self.assertFalse(Filter(definition).check(event))
+        self.assertFalse(Filter(self.hs, definition)._check(event))
 
         event = MockEvent(
             sender="@foo:bar",
@@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             content={EventContentFields.LABELS: ["#notfun"]},
         )
 
-        self.assertTrue(Filter(definition).check(event))
+        self.assertTrue(Filter(self.hs, definition)._check(event))
 
     def test_filter_presence_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
@@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        results = user_filter.filter_presence(events=events)
+        results = self.get_success(user_filter.filter_presence(events=events))
         self.assertEquals(events, results)
 
     def test_filter_presence_no_match(self):
@@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        results = user_filter.filter_presence(events=events)
+        results = self.get_success(user_filter.filter_presence(events=events))
         self.assertEquals([], results)
 
     def test_filter_room_state_match(self):
@@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        results = user_filter.filter_room_state(events=events)
+        results = self.get_success(user_filter.filter_room_state(events=events))
         self.assertEquals(events, results)
 
     def test_filter_room_state_no_match(self):
@@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        results = user_filter.filter_room_state(events)
+        results = self.get_success(user_filter.filter_room_state(events))
         self.assertEquals([], results)
 
     def test_filter_rooms(self):
@@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             "!not_included:example.com",  # Disallowed because not in rooms.
         ]
 
-        filtered_room_ids = list(Filter(definition).filter_rooms(room_ids))
+        filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
 
         self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_filter_relations(self):
+        events = [
+            # An event without a relation.
+            MockEvent(
+                event_id="$no_relation",
+                sender="@foo:bar",
+                type="org.matrix.custom.event",
+                room_id="!foo:bar",
+            ),
+            # An event with a relation.
+            MockEvent(
+                event_id="$with_relation",
+                sender="@foo:bar",
+                type="org.matrix.custom.event",
+                room_id="!foo:bar",
+            ),
+            # Non-EventBase objects get passed through.
+            {},
+        ]
+
+        # For the following tests we patch the datastore method (intead of injecting
+        # events). This is a bit cheeky, but tests the logic of _check_event_relations.
+
+        # Filter for a particular sender.
+        definition = {
+            "io.element.relation_senders": ["@foo:bar"],
+        }
+
+        async def events_have_relations(*args, **kwargs):
+            return ["$with_relation"]
+
+        with patch.object(
+            self.datastore, "events_have_relations", new=events_have_relations
+        ):
+            filtered_events = list(
+                self.get_success(
+                    Filter(self.hs, definition)._check_event_relations(events)
+                )
+            )
+        self.assertEquals(filtered_events, events[1:])
+
     def test_add_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 765258c47a..d8668d56b2 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -94,3 +94,12 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
         # The default Metrics Flags are off by default.
         config = HomeServerConfig.load_config("", ["-c", self.config_file])
         self.assertFalse(config.metrics.metrics_flags.known_servers)
+
+    def test_depreciated_identity_server_flag_throws_error(self):
+        self.generate_config()
+        # Needed to ensure that actual key/value pair added below don't end up on a line with a comment
+        self.add_lines_to_config([" "])
+        # Check that presence of "trust_identity_server_for_password" throws config error
+        self.add_lines_to_config(["trust_identity_server_for_password_resets: true"])
+        with self.assertRaises(ConfigError):
+            HomeServerConfig.load_config("", ["-c", self.config_file])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index cbecc1c20f..4d1e154578 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,4 +1,4 @@
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.unittest import logcontext_clean
+from tests.unittest import logcontext_clean, override_config
 
 
 class MockPerspectiveServer:
@@ -197,7 +197,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # self.assertFalse(d.called)
         self.get_success(d)
 
-    def test_verify_for_server_locally(self):
+    def test_verify_for_local_server(self):
         """Ensure that locally signed JSON can be verified without fetching keys
         over federation
         """
@@ -209,6 +209,56 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
         self.get_success(d)
 
+    OLD_KEY = signedjson.key.generate_signing_key("old")
+
+    @override_config(
+        {
+            "old_signing_keys": {
+                f"{OLD_KEY.alg}:{OLD_KEY.version}": {
+                    "key": encode_verify_key_base64(OLD_KEY.verify_key),
+                    "expired_ts": 1000,
+                }
+            }
+        }
+    )
+    def test_verify_for_local_server_old_key(self):
+        """Can also use keys in old_signing_keys for verification"""
+        json1 = {}
+        signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY)
+
+        kr = keyring.Keyring(self.hs)
+        d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+        self.get_success(d)
+
+    def test_verify_for_local_server_unknown_key(self):
+        """Local keys that we no longer have should be fetched via the fetcher"""
+
+        # the key we'll sign things with (nb, not known to the Keyring)
+        key2 = signedjson.key.generate_signing_key("2")
+
+        # set up a mock fetcher which will return the key
+        async def get_keys(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, self.hs.hostname)
+            self.assertEqual(key_ids, [get_key_id(key2)])
+
+            return {get_key_id(key2): FetchKeyResult(get_verify_key(key2), 1200)}
+
+        mock_fetcher = Mock()
+        mock_fetcher.get_keys = Mock(side_effect=get_keys)
+        kr = keyring.Keyring(
+            self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+        )
+
+        # sign the json
+        json1 = {}
+        signedjson.sign.sign_json(json1, self.hs.hostname, key2)
+
+        # ... and check we can verify it.
+        d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+        self.get_success(d)
+
     def test_verify_json_for_server_with_null_valid_until_ms(self):
         """Tests that we correctly handle key requests for keys we've stored
         with a null `ts_valid_until_ms`
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1f6a924452..d6f14e2dba 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             make_awaitable(([event], None))
         )
 
-        self.handler.notify_interested_services_ephemeral("receipt_key", 580)
+        self.handler.notify_interested_services_ephemeral(
+            "receipt_key", 580, ["@fakerecipient:example.com"]
+        )
         self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
             interested_service, [event]
         )
@@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             make_awaitable(([event], None))
         )
 
-        self.handler.notify_interested_services_ephemeral("receipt_key", 579)
+        self.handler.notify_interested_services_ephemeral(
+            "receipt_key", 580, ["@fakerecipient:example.com"]
+        )
         self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
 
     def _mkservice(self, is_interested, protocols=None):
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 12857053e7..72e176da75 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
@@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_failure(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             ),
             ResourceLimitError,
@@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         # If not in monthly active cohort
         self.get_failure(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             ),
             ResourceLimitError,
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.clock.time_msec())
         )
         self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
@@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         # Ensure does not raise exception
         self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index be008227df..0ea4e753e2 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,13 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from unittest.mock import Mock
 
 import synapse.api.errors
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
-from synapse.config.room_directory import RoomDirectoryConfig
 from synapse.rest.client import directory, login, room
 from synapse.types import RoomAlias, create_requester
 
@@ -394,22 +393,15 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
-        # We cheekily override the config to add custom alias creation rules
-        config = {}
+    def default_config(self):
+        config = super().default_config()
+
+        # Add custom alias creation rules to the config.
         config["alias_creation_rules"] = [
             {"user_id": "*", "alias": "#unofficial_*", "action": "allow"}
         ]
-        config["room_list_publication_rules"] = []
 
-        rd_config = RoomDirectoryConfig()
-        rd_config.read_config(config)
-
-        self.hs.config.roomdirectory.is_alias_creation_allowed = (
-            rd_config.is_alias_creation_allowed
-        )
-
-        return hs
+        return config
 
     def test_denied(self):
         room_id = self.helper.create_room_as(self.user_id)
@@ -417,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT",
             b"directory/room/%23test%3Atest",
-            ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+            {"room_id": room_id},
         )
         self.assertEquals(403, channel.code, channel.result)
 
@@ -427,14 +419,35 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT",
             b"directory/room/%23unofficial_test%3Atest",
-            ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+            {"room_id": room_id},
         )
         self.assertEquals(200, channel.code, channel.result)
 
+    def test_denied_during_creation(self):
+        """A room alias that is not allowed should be rejected during creation."""
+        # Invalid room alias.
+        self.helper.create_room_as(
+            self.user_id,
+            expect_code=403,
+            extra_content={"room_alias_name": "foo"},
+        )
 
-class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
-    data = {"room_alias_name": "unofficial_test"}
+    def test_allowed_during_creation(self):
+        """A valid room alias should be allowed during creation."""
+        room_id = self.helper.create_room_as(
+            self.user_id,
+            extra_content={"room_alias_name": "unofficial_test"},
+        )
 
+        channel = self.make_request(
+            "GET",
+            b"directory/room/%23unofficial_test%3Atest",
+        )
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(channel.json_body["room_id"], room_id)
+
+
+class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -443,27 +456,30 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def prepare(self, reactor, clock, hs):
-        self.allowed_user_id = self.register_user("allowed", "pass")
-        self.allowed_access_token = self.login("allowed", "pass")
+    data = {"room_alias_name": "unofficial_test"}
+    allowed_localpart = "allowed"
 
-        self.denied_user_id = self.register_user("denied", "pass")
-        self.denied_access_token = self.login("denied", "pass")
+    def default_config(self):
+        config = super().default_config()
 
-        # This time we add custom room list publication rules
-        config = {}
-        config["alias_creation_rules"] = []
+        # Add custom room list publication rules to the config.
         config["room_list_publication_rules"] = [
+            {
+                "user_id": "@" + self.allowed_localpart + "*",
+                "alias": "#unofficial_*",
+                "action": "allow",
+            },
             {"user_id": "*", "alias": "*", "action": "deny"},
-            {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"},
         ]
 
-        rd_config = RoomDirectoryConfig()
-        rd_config.read_config(config)
+        return config
 
-        self.hs.config.roomdirectory.is_publishing_room_allowed = (
-            rd_config.is_publishing_room_allowed
-        )
+    def prepare(self, reactor, clock, hs):
+        self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
+        self.allowed_access_token = self.login(self.allowed_localpart, "pass")
+
+        self.denied_user_id = self.register_user("denied", "pass")
+        self.denied_access_token = self.login("denied", "pass")
 
         return hs
 
@@ -505,10 +521,23 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             self.allowed_user_id,
             tok=self.allowed_access_token,
             extra_content=self.data,
-            is_public=False,
+            is_public=True,
             expect_code=200,
         )
 
+    def test_denied_publication_with_invalid_alias(self):
+        """
+        Try to create a room, register an alias for it, and publish it,
+        as a user WITH permission to publish rooms.
+        """
+        self.helper.create_room_as(
+            self.allowed_user_id,
+            tok=self.allowed_access_token,
+            extra_content={"room_alias_name": "foo"},
+            is_public=True,
+            expect_code=403,
+        )
+
     def test_can_create_as_private_room_after_rejection(self):
         """
         After failing to publish a room with an alias as a user without publish permission,
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 0c3b86fda9..f0723892e4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         fallback_key = {"alg1:k1": "key1"}
+        fallback_key2 = {"alg1:k2": "key2"}
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
@@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
+        # re-uploading the same fallback key should still result in no unused fallback
+        # keys
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key},
+            )
+        )
+
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # uploading a new fallback key should result in an unused fallback key
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key2},
+            )
+        )
+
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, ["alg1"])
+
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
         self.get_success(
@@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(
             res,
-            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
     def test_replace_master_key(self):
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 7dd4a5a367..08e9730d4d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -31,7 +31,10 @@ from tests.unittest import override_config
 
 # (possibly experimental) login flows we expect to appear in the list after the normal
 # ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+    {"type": "m.login.application_service"},
+    {"type": "uk.half-shot.msc2778.login.application_service"},
+]
 
 # a mock instance which the dummy auth providers delegate to, so we can see what's going
 # on
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index db691c4c1c..cd6f2c77ae 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self):
-        self.store.count_monthly_users = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self):
-        self.store.get_monthly_active_count = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
@@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        self.store.get_monthly_active_count = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value)
         )
         self.get_failure(
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index d3d0bf1ac5..7b95844b55 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -14,6 +14,8 @@
 from typing import Any, Iterable, List, Optional, Tuple
 from unittest import mock
 
+from twisted.internet.defer import ensureDeferred
+
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
@@ -316,6 +318,59 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
+    def test_room_hierarchy_cache(self) -> None:
+        """In-flight room hierarchy requests are deduplicated."""
+        # Run two `get_room_hierarchy` calls up until they block.
+        deferred1 = ensureDeferred(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        deferred2 = ensureDeferred(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+
+        # Complete the two calls.
+        result1 = self.get_success(deferred1)
+        result2 = self.get_success(deferred2)
+
+        # Both `get_room_hierarchy` calls should return the same result.
+        expected = [(self.space, [self.room]), (self.room, ())]
+        self._assert_hierarchy(result1, expected)
+        self._assert_hierarchy(result2, expected)
+        self.assertIs(result1, result2)
+
+        # A subsequent `get_room_hierarchy` call should not reuse the result.
+        result3 = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result3, expected)
+        self.assertIsNot(result1, result3)
+
+    def test_room_hierarchy_cache_sharing(self) -> None:
+        """Room hierarchy responses for different users are not shared."""
+        user2 = self.register_user("user2", "pass")
+
+        # Make the room within the space invite-only.
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.JoinRules,
+            body={"join_rule": JoinRules.INVITE},
+            tok=self.token,
+        )
+
+        # Run two `get_room_hierarchy` calls for different users up until they block.
+        deferred1 = ensureDeferred(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
+
+        # Complete the two calls.
+        result1 = self.get_success(deferred1)
+        result2 = self.get_success(deferred2)
+
+        # The `get_room_hierarchy` calls should return different results.
+        self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())])
+        self._assert_hierarchy(result2, [(self.space, [self.room])])
+
     def _create_room_with_join_rule(
         self, join_rule: str, room_version: Optional[str] = None, **extra_content
     ) -> str:
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 339c039914..638186f173 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 from typing import Optional
+from unittest.mock import Mock
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.filtering import Filtering
 from synapse.api.room_versions import RoomVersions
 from synapse.handlers.sync import SyncConfig
 from synapse.rest import admin
@@ -197,7 +198,7 @@ def generate_sync_config(
     _request_key += 1
     return SyncConfig(
         user=UserID.from_string(user_id),
-        filter_collection=DEFAULT_FILTER_COLLECTION,
+        filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
         is_guest=False,
         request_key=("request_key", _request_key),
         device_id=device_id,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.protocol import Protocol
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
     def setUp(self):
         super().setUp()
 
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 192073c520..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
                 % server_and_media_id_2
             ),
         )
+
+
+class PurgeHistoryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
+        self.url_status = "/_synapse/admin/v1/purge_history_status/"
+
+    def test_purge_history(self):
+        """
+        Simple test of purge history API.
+        Test only that is is possible to call, get status 200 and purge_id.
+        """
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content={"delete_local_events": True, "purge_up_to_ts": 0},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("purge_id", channel.json_body)
+        purge_id = channel.json_body["purge_id"]
+
+        # get status
+        channel = self.make_request(
+            "GET",
+            self.url_status + purge_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 78c48db552..1786316763 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,8 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+from typing import Collection
+
+from parameterized import parameterized
 
 import synapse.rest.admin
+from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 
@@ -30,6 +35,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
+    @parameterized.expand(
+        [
+            ("GET", "/_synapse/admin/v1/background_updates/enabled"),
+            ("POST", "/_synapse/admin/v1/background_updates/enabled"),
+            ("GET", "/_synapse/admin/v1/background_updates/status"),
+            ("POST", "/_synapse/admin/v1/background_updates/start_job"),
+        ]
+    )
+    def test_requester_is_no_admin(self, method: str, url: str):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        self.register_user("user", "pass", admin=False)
+        other_user_tok = self.login("user", "pass")
+
+        channel = self.make_request(
+            method,
+            url,
+            content={},
+            access_token=other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+        url = "/_synapse/admin/v1/background_updates/start_job"
+
+        # empty content
+        channel = self.make_request(
+            "POST",
+            url,
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+        # job_name invalid
+        channel = self.make_request(
+            "POST",
+            url,
+            content={"job_name": "unknown"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
     def _register_bg_update(self):
         "Adds a bg update but doesn't start it"
 
@@ -60,7 +119,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled, but none should be running.
         self.assertDictEqual(
@@ -82,7 +141,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled, and one should be running.
         self.assertDictEqual(
@@ -114,7 +173,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/enabled",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(channel.json_body, {"enabled": True})
 
         # Disable the BG updates
@@ -124,7 +183,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             content={"enabled": False},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(channel.json_body, {"enabled": False})
 
         # Advance a bit and get the current status, note this will finish the in
@@ -137,7 +196,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertDictEqual(
             channel.json_body,
             {
@@ -162,7 +221,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # There should be no change from the previous /status response.
         self.assertDictEqual(
@@ -188,7 +247,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             content={"enabled": True},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         self.assertDictEqual(channel.json_body, {"enabled": True})
 
@@ -199,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
             "/_synapse/admin/v1/background_updates/status",
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
         # Background updates should be enabled and making progress.
         self.assertDictEqual(
@@ -216,3 +275,82 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "enabled": True,
             },
         )
+
+    @parameterized.expand(
+        [
+            ("populate_stats_process_rooms", ["populate_stats_process_rooms"]),
+            (
+                "regenerate_directory",
+                [
+                    "populate_user_directory_createtables",
+                    "populate_user_directory_process_rooms",
+                    "populate_user_directory_process_users",
+                    "populate_user_directory_cleanup",
+                ],
+            ),
+        ]
+    )
+    def test_start_backround_job(self, job_name: str, updates: Collection[str]):
+        """
+        Test that background updates add to database and be processed.
+
+        Args:
+            job_name: name of the job to call with API
+            updates: collection of background updates to be started
+        """
+
+        # no background update is waiting
+        self.assertTrue(
+            self.get_success(
+                self.store.db_pool.updates.has_completed_background_updates()
+            )
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/background_updates/start_job",
+            content={"job_name": job_name},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        # test that each background update is waiting now
+        for update in updates:
+            self.assertFalse(
+                self.get_success(
+                    self.store.db_pool.updates.has_completed_background_update(update)
+                )
+            )
+
+        self.wait_for_background_updates()
+
+        # background updates are done
+        self.assertTrue(
+            self.get_success(
+                self.store.db_pool.updates.has_completed_background_updates()
+            )
+        )
+
+    def test_start_backround_job_twice(self):
+        """Test that add a background update twice return an error."""
+
+        # add job to database
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                table="background_updates",
+                values={
+                    "update_name": "populate_stats_process_rooms",
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/background_updates/start_job",
+            content={"job_name": "populate_stats_process_rooms"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46116644ce..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -14,12 +14,16 @@
 
 import json
 import urllib.parse
+from http import HTTPStatus
 from typing import List, Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import Codes
+from synapse.handlers.pagination import PaginationHandler
 from synapse.rest.client import directory, events, login, room
 
 from tests import unittest
@@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            json.dumps({}),
+            {},
             access_token=self.other_user_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_room_does_not_exist(self):
@@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             url,
-            json.dumps({}),
+            {},
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
     def test_room_is_not_valid(self):
@@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             url,
-            json.dumps({}),
+            {},
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "invalidroom is not a legal room ID",
             channel.json_body["error"],
@@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("kicked_users", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "User must be our own: @not:exist.bla",
             channel.json_body["error"],
@@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
     def test_purge_is_not_bool(self):
@@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
     def test_purge_room_and_block(self):
@@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "DELETE",
             self.url.encode("ascii"),
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(None, channel.json_body["new_room_id"])
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_block_unknown_room(self, purge: bool) -> None:
+        """
+        We can block an unknown room. In this case, the `purge` argument
+        should be ignored.
+        """
+        room_id = "!unknown:test"
+
+        # The room isn't already in the blocked rooms table
+        self._is_blocked(room_id, expect=False)
+
+        # Request the room be blocked.
+        channel = self.make_request(
+            "DELETE",
+            f"/_synapse/admin/v1/rooms/{room_id}",
+            {"block": True, "purge": purge},
+            access_token=self.admin_user_tok,
+        )
+
+        # The room is now blocked.
+        self.assertEqual(
+            HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+        )
+        self._is_blocked(room_id)
+
     def test_shutdown_room_consent(self):
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
@@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             json.dumps({"history_visibility": "world_readable"}),
             access_token=self.other_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Test that room is not purged
         with self.assertRaises(AssertionError):
@@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
         self.assertIn("new_room_id", channel.json_body)
         self.assertIn("failed_to_kick_users", channel.json_body)
@@ -418,17 +447,616 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+        url = "events?timeout=0&room_id=" + room_id
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+
+class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.consent.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = f"/_synapse/admin/v2/rooms/{self.room_id}"
+        self.url_status_by_room_id = (
+            f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status"
+        )
+        self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+            ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+        ]
+    )
+    def test_requester_is_no_admin(self, method: str, url: str):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        channel = self.make_request(
+            method,
+            url % self.room_id,
+            content={},
+            access_token=self.other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+            ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+        ]
+    )
+    def test_room_does_not_exist(self, method: str, url: str):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+
+        channel = self.make_request(
+            method,
+            url % "!unknown:test",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    @parameterized.expand(
+        [
+            ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+            ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+        ]
+    )
+    def test_room_is_not_valid(self, method: str, url: str):
+        """
+        Check that invalid room names, return an error 400.
+        """
+
+        channel = self.make_request(
+            method,
+            url % "invalidroom",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "invalidroom is not a legal room ID",
+            channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the user ID must be from local server but it does not have to exist.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": "@unknown:test"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only local users can create new room to move members.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": "@not:exist.bla"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+            "User must be our own: @not:exist.bla",
+            channel.json_body["error"],
         )
 
+    def test_block_is_not_bool(self):
+        """
+        If parameter `block` is not boolean, return an error
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"block": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_is_not_bool(self):
+        """
+        If parameter `purge` is not boolean, return an error
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"purge": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_delete_expired_status(self):
+        """Test that the task status is removed after expiration."""
+
+        # first task, do not purge, that we can create a second task
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"purge": False},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id1 = channel.json_body["delete_id"]
+
+        # go ahead
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        # second task
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id2 = channel.json_body["delete_id"]
+
+        # get status
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(2, len(channel.json_body["results"]))
+        self.assertEqual("complete", channel.json_body["results"][0]["status"])
+        self.assertEqual("complete", channel.json_body["results"][1]["status"])
+        self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
+        self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+
+        # get status after more than clearing time for first task
+        # second task is not cleared
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+        self.assertEqual("complete", channel.json_body["results"][0]["status"])
+        self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+
+        # get status after more than clearing time for all tasks
+        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_delete_same_room_twice(self):
+        """Test that the call for delete a room at second time gives an exception."""
+
+        body = {"new_room_user_id": self.admin_user}
+
+        # first call to delete room
+        # and do not wait for finish the task
+        first_channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content=body,
+            access_token=self.admin_user_tok,
+            await_result=False,
+        )
+
+        # second call to delete room
+        second_channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content=body,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(
+            HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
+        )
+        self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
+        self.assertEqual(
+            f"History purge already in progress for {self.room_id}",
+            second_channel.json_body["error"],
+        )
+
+        # get result of first call
+        first_channel.await_result()
+        self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+        self.assertIn("delete_id", first_channel.json_body)
+
+        # check status after finish the task
+        self._test_result(
+            first_channel.json_body["delete_id"],
+            self.other_user,
+            expect_new_room=True,
+        )
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": True, "purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_purge_room_and_not_block(self):
+        """Test to purge a room and do not block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": False, "purge": True},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_block_room_and_not_purge(self):
+        """Test to block a room without purging it.
+        Members will not be moved to a new room and will not receive a message.
+        The room will not be purged.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url.encode("ascii"),
+            content={"block": True, "purge": False},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user)
+
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+        )
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": self.admin_user},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+            user_id=self.other_user,
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+        channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            content={"history_visibility": "world_readable"},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            content={"new_room_user_id": self.admin_user},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertIn("delete_id", channel.json_body)
+        delete_id = channel.json_body["delete_id"]
+
+        self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+        channel = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+            user_id=self.other_user,
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(self.room_id, expect_code=403)
+
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+        """Assert that the room is blocked or not"""
+        d = self.store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _has_no_members(self, room_id: str) -> None:
+        """Assert there is now no longer anyone in the room"""
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def _is_member(self, room_id: str, user_id: str) -> None:
+        """Test that user is member of the room"""
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertIn(user_id, users_in_room)
+
+    def _is_purged(self, room_id: str) -> None:
+        """Test that the following tables have been purged of all rows related to the room."""
+        for table in PURGE_TABLES:
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
+
+    def _assert_peek(self, room_id: str, expect_code: int) -> None:
+        """Assert that the admin user can (or cannot) peek into the room."""
+
+        url = f"rooms/{room_id}/initialSync"
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
         url = "events?timeout=0&room_id=" + room_id
         channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
+        self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+    def _test_result(
+        self,
+        delete_id: str,
+        kicked_user: str,
+        expect_new_room: bool = False,
+    ) -> None:
+        """
+        Test that the result is the expected.
+        Uses both APIs (status by room_id and delete_id)
+
+        Args:
+            delete_id: id of this purge
+            kicked_user: a user_id which is kicked from the room
+            expect_new_room: if we expect that a new room was created
+        """
+
+        # get information by room_id
+        channel_room_id = self.make_request(
+            "GET",
+            self.url_status_by_room_id,
+            access_token=self.admin_user_tok,
+        )
         self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+            HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
         )
+        self.assertEqual(1, len(channel_room_id.json_body["results"]))
+        self.assertEqual(
+            delete_id, channel_room_id.json_body["results"][0]["delete_id"]
+        )
+
+        # get information by delete_id
+        channel_delete_id = self.make_request(
+            "GET",
+            self.url_status_by_delete_id + delete_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(
+            HTTPStatus.OK,
+            channel_delete_id.code,
+            msg=channel_delete_id.json_body,
+        )
+
+        # test values that are the same in both responses
+        for content in [
+            channel_room_id.json_body["results"][0],
+            channel_delete_id.json_body,
+        ]:
+            self.assertEqual("complete", content["status"])
+            self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0])
+            self.assertIn("failed_to_kick_users", content["shutdown_room"])
+            self.assertIn("local_aliases", content["shutdown_room"])
+            self.assertNotIn("error", content)
+
+            if expect_new_room:
+                self.assertIsNotNone(content["shutdown_room"]["new_room_id"])
+            else:
+                self.assertIsNone(content["shutdown_room"]["new_room_id"])
 
 
 class RoomTestCase(unittest.HomeserverTestCase):
@@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
 
         # Check request completed successfully
-        self.assertEqual(200, int(channel.code), msg=channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check that response json body contains a "rooms" key
         self.assertTrue(
@@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 url.encode("ascii"),
                 access_token=self.admin_user_tok,
             )
-            self.assertEqual(
-                200, int(channel.result["code"]), msg=channel.result["body"]
-            )
+            self.assertEqual(200, channel.code, msg=channel.json_body)
 
             self.assertTrue("rooms" in channel.json_body)
             for r in channel.json_body["rooms"]:
@@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url.encode("ascii"),
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
     def test_correct_room_attributes(self):
         """Test the correct attributes for a room are returned"""
@@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             {"room_id": room_id},
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Set this new alias as the canonical alias for this room
         self.helper.send_state(
@@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url.encode("ascii"),
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check that rooms were returned
         self.assertTrue("rooms" in channel.json_body)
@@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             {"room_id": room_id},
             access_token=admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Set this new alias as the canonical alias for this room
         self.helper.send_state(
@@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.second_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_invalid_parameter(self):
@@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
 
     def test_local_user_does_not_exist(self):
@@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
     def test_remote_user(self):
@@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "This endpoint can only be used with local users",
             channel.json_body["error"],
@@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual("No known servers", channel.json_body["error"])
 
     def test_room_is_not_valid(self):
@@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "invalidroom was not legal room ID or room alias",
             channel.json_body["error"],
@@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.public_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_not_member(self):
@@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_join_private_room_if_member(self):
@@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
         # Join user to room.
@@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_owner(self):
@@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST",
             url,
-            content=body.encode(encoding="utf_8"),
+            content=body,
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["room_id"])
 
         # Validate if user is a member of the room
@@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_context_as_non_admin(self):
@@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
                 % (room_id, events[midway]["event_id"]),
                 access_token=tok,
             )
-            self.assertEquals(
-                403, int(channel.result["code"]), msg=channel.result["body"]
-            )
+            self.assertEquals(403, channel.code, msg=channel.json_body)
             self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_context_as_admin(self):
@@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             % (room_id, events[midway]["event_id"]),
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(200, channel.code, msg=channel.json_body)
         self.assertEquals(
             channel.json_body["event"]["event_id"], events[midway]["event_id"]
         )
@@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room (we should have received an
         # invite) and can ban a user.
@@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Now we test that we can join the room and ban a user.
         self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -1595,13 +2219,241 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
         #
         # (Note we assert the error message to ensure that it's not denied for
         # some other reason)
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             channel.json_body["error"],
             "No local admin user in room with power to update power levels.",
         )
 
 
+class BlockRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self._store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = "/_synapse/admin/v1/rooms/%s/block"
+
+    @parameterized.expand([("PUT",), ("GET",)])
+    def test_requester_is_no_admin(self, method: str):
+        """If the user is not a server admin, an error 403 is returned."""
+
+        channel = self.make_request(
+            method,
+            self.url % self.room_id,
+            content={},
+            access_token=self.other_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    @parameterized.expand([("PUT",), ("GET",)])
+    def test_room_is_not_valid(self, method: str):
+        """Check that invalid room names, return an error 400."""
+
+        channel = self.make_request(
+            method,
+            self.url % "invalidroom",
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "invalidroom is not a legal room ID",
+            channel.json_body["error"],
+        )
+
+    def test_block_is_not_valid(self):
+        """If parameter `block` is not valid, return an error."""
+
+        # `block` is not valid
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            content={"block": "NotBool"},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+        # `block` is not set
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+        # no content is send
+        channel = self.make_request(
+            "PUT",
+            self.url % self.room_id,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+    def test_block_room(self):
+        """Test that block a room is successful."""
+
+        def _request_and_test_block_room(room_id: str) -> None:
+            self._is_blocked(room_id, expect=False)
+            channel = self.make_request(
+                "PUT",
+                self.url % room_id,
+                content={"block": True},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self._is_blocked(room_id, expect=True)
+
+        # known internal room
+        _request_and_test_block_room(self.room_id)
+
+        # unknown internal room
+        _request_and_test_block_room("!unknown:test")
+
+        # unknown remote room
+        _request_and_test_block_room("!unknown:remote")
+
+    def test_block_room_twice(self):
+        """Test that block a room that is already blocked is successful."""
+
+        self._is_blocked(self.room_id, expect=False)
+        for _ in range(2):
+            channel = self.make_request(
+                "PUT",
+                self.url % self.room_id,
+                content={"block": True},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self._is_blocked(self.room_id, expect=True)
+
+    def test_unblock_room(self):
+        """Test that unblock a room is successful."""
+
+        def _request_and_test_unblock_room(room_id: str) -> None:
+            self._block_room(room_id)
+
+            channel = self.make_request(
+                "PUT",
+                self.url % room_id,
+                content={"block": False},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self._is_blocked(room_id, expect=False)
+
+        # known internal room
+        _request_and_test_unblock_room(self.room_id)
+
+        # unknown internal room
+        _request_and_test_unblock_room("!unknown:test")
+
+        # unknown remote room
+        _request_and_test_unblock_room("!unknown:remote")
+
+    def test_unblock_room_twice(self):
+        """Test that unblock a room that is not blocked is successful."""
+
+        self._block_room(self.room_id)
+        for _ in range(2):
+            channel = self.make_request(
+                "PUT",
+                self.url % self.room_id,
+                content={"block": False},
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self._is_blocked(self.room_id, expect=False)
+
+    def test_get_blocked_room(self):
+        """Test get status of a blocked room"""
+
+        def _request_blocked_room(room_id: str) -> None:
+            self._block_room(room_id)
+
+            channel = self.make_request(
+                "GET",
+                self.url % room_id,
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertTrue(channel.json_body["block"])
+            self.assertEqual(self.other_user, channel.json_body["user_id"])
+
+        # known internal room
+        _request_blocked_room(self.room_id)
+
+        # unknown internal room
+        _request_blocked_room("!unknown:test")
+
+        # unknown remote room
+        _request_blocked_room("!unknown:remote")
+
+    def test_get_unblocked_room(self):
+        """Test get status of a unblocked room"""
+
+        def _request_unblocked_room(room_id: str) -> None:
+            self._is_blocked(room_id, expect=False)
+
+            channel = self.make_request(
+                "GET",
+                self.url % room_id,
+                access_token=self.admin_user_tok,
+            )
+            self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+            self.assertFalse(channel.json_body["block"])
+            self.assertNotIn("user_id", channel.json_body)
+
+        # known internal room
+        _request_unblocked_room(self.room_id)
+
+        # unknown internal room
+        _request_unblocked_room("!unknown:test")
+
+        # unknown remote room
+        _request_unblocked_room("!unknown:remote")
+
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+        """Assert that the room is blocked or not"""
+        d = self._store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _block_room(self, room_id: str) -> None:
+        """Block a room in database"""
+        self.get_success(self._store.block_room(room_id, self.other_user))
+        self._is_blocked(room_id, expect=True)
+
+
 PURGE_TABLES = [
     "current_state_events",
     "event_backward_extremities",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 25e8d6cf27..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.admin_user, device_id=None, valid_until_ms=None
             )
         )
 
         self.other_user = self.register_user("user", "pass", displayname="User")
         self.other_user_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.other_user, device_id=None, valid_until_ms=None
             )
         )
@@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-    def test_no_auth(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_no_auth(self, method: str):
         """
         Try to get information of an user without authentication.
         """
-        channel = self.make_request("POST", self.url)
+        channel = self.make_request(method, self.url)
         self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_not_admin(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_requester_is_not_admin(self, method: str):
         """
         If the user is not a server admin, an error is returned.
         """
         other_user_token = self.login("user", "pass")
 
-        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        channel = self.make_request(method, self.url, access_token=other_user_token)
         self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_is_not_local(self):
+    @parameterized.expand(["POST", "DELETE"])
+    def test_user_is_not_local(self, method: str):
         """
         Tests that shadow-banning for a user that is not a local returns a 400
         """
         url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
 
-        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        channel = self.make_request(method, url, access_token=self.admin_user_tok)
         self.assertEqual(400, channel.code, msg=channel.json_body)
 
     def test_success(self):
@@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
         self.assertTrue(result.shadow_banned)
 
+        # Un-shadow-ban the user.
+        channel = self.make_request(
+            "DELETE", self.url, access_token=self.admin_user_tok
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is no longer shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
 
 class RateLimitTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index b9e3602552..249808b031 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     def test_get_does_include_msc3244_fields_when_enabled(self):
         access_token = self.get_success(
-            self.auth_handler.get_access_token_for_user_id(
+            self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..aca03afd0e 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,12 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import directory, login, room
+from synapse.server import HomeServer
 from synapse.types import RoomAlias
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["require_membership_for_aliases"] = True
 
@@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        """Create two local users and access tokens for them.
+        One of them creates a room."""
         self.room_owner = self.register_user("room_owner", "test")
         self.room_owner_tok = self.login("room_owner", "test")
 
@@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "test")
         self.user_tok = self.login("user", "test")
 
-    def test_state_event_not_in_room(self):
+    def test_state_event_not_in_room(self) -> None:
         self.ensure_user_left_room()
-        self.set_alias_via_state_event(403)
+        self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
 
-    def test_directory_endpoint_not_in_room(self):
+    def test_directory_endpoint_not_in_room(self) -> None:
         self.ensure_user_left_room()
-        self.set_alias_via_directory(403)
+        self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
 
-    def test_state_event_in_room_too_long(self):
+    def test_state_event_in_room_too_long(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(400, alias_length=256)
+        self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
 
-    def test_directory_in_room_too_long(self):
+    def test_directory_in_room_too_long(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_directory(400, alias_length=256)
+        self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
 
     @override_config({"default_room_version": 5})
-    def test_state_event_user_in_v5_room(self):
+    def test_state_event_user_in_v5_room(self) -> None:
         """Test that a regular user can add alias events before room v6"""
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(200)
+        self.set_alias_via_state_event(HTTPStatus.OK)
 
     @override_config({"default_room_version": 6})
-    def test_state_event_v6_room(self):
+    def test_state_event_v6_room(self) -> None:
         """Test that a regular user can *not* add alias events from room v6"""
         self.ensure_user_joined_room()
-        self.set_alias_via_state_event(403)
+        self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
 
-    def test_directory_in_room(self):
+    def test_directory_in_room(self) -> None:
         self.ensure_user_joined_room()
-        self.set_alias_via_directory(200)
+        self.set_alias_via_directory(HTTPStatus.OK)
 
-    def test_room_creation_too_long(self):
+    def test_room_creation_too_long(self) -> None:
         url = "/_matrix/client/r0/createRoom"
 
         # We use deliberately a localpart under the length threshold so
@@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
 
-    def test_room_creation(self):
+    def test_room_creation(self) -> None:
         url = "/_matrix/client/r0/createRoom"
 
         # Check with an alias of allowed length. There should already be
@@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    def test_deleting_alias_via_directory(self) -> None:
+        # Add an alias for the room. We must be joined to do so.
+        self.ensure_user_joined_room()
+        alias = self.set_alias_via_directory(HTTPStatus.OK)
+
+        # Then try to remove the alias
+        channel = self.make_request(
+            "DELETE",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    def test_deleting_nonexistant_alias(self) -> None:
+        # Check that no alias exists
+        alias = "#potato:test"
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertIn("error", channel.json_body, channel.json_body)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
+
+        # Then try to remove the alias
+        channel = self.make_request(
+            "DELETE",
+            f"/_matrix/client/r0/directory/room/{alias}",
+            access_token=self.user_tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertIn("error", channel.json_body, channel.json_body)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
 
-    def set_alias_via_state_event(self, expected_code, alias_length=5):
+    def set_alias_via_state_event(
+        self, expected_code: HTTPStatus, alias_length: int = 5
+    ) -> None:
         url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
             self.room_id,
             self.hs.hostname,
@@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, expected_code, channel.result)
 
-    def set_alias_via_directory(self, expected_code, alias_length=5):
-        url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+    def set_alias_via_directory(
+        self, expected_code: HTTPStatus, alias_length: int = 5
+    ) -> str:
+        alias = self.random_alias(alias_length)
+        url = "/_matrix/client/r0/directory/room/%s" % alias
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
@@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
+        return alias
 
-    def random_alias(self, length):
+    def random_alias(self, length: int) -> str:
         return RoomAlias(random_string(length), self.hs.hostname).to_string()
 
-    def ensure_user_left_room(self):
+    def ensure_user_left_room(self) -> None:
         self.ensure_membership("leave")
 
-    def ensure_user_joined_room(self):
+    def ensure_user_joined_room(self) -> None:
         self.ensure_membership("join")
 
-    def ensure_membership(self, membership):
+    def ensure_membership(self, membership: str) -> None:
         try:
             if membership == "leave":
                 self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a63f04bd41..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
 
 # (possibly experimental) login flows we expect to appear in the list after the normal
 # ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+    {"type": "m.login.application_service"},
+    {"type": "uk.half-shot.msc2778.login.application_service"},
+]
 
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     jwt_secret = "secret"
     jwt_algorithm = "HS256"
+    base_config = {
+        "enabled": True,
+        "secret": jwt_secret,
+        "algorithm": jwt_algorithm,
+    }
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_secret
-        self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+
+        # If jwt_config has been defined (eg via @override_config), don't replace it.
+        if config.get("jwt_config") is None:
+            config["jwt_config"] = self.base_config
+
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "issuer": "test-issuer",
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
     def test_login_iss(self):
         """Test validating the issuer claim."""
         # A valid issuer.
@@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
-    @override_config(
-        {
-            "jwt_config": {
-                "jwt_enabled": True,
-                "secret": jwt_secret,
-                "algorithm": jwt_algorithm,
-                "audiences": ["test-audience"],
-            }
-        }
-    )
+    @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
     def test_login_aud(self):
         """Test validating the audience claim."""
         # A valid audience.
@@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
 
+    def test_login_default_sub(self):
+        """Test reading user ID from the default subject claim."""
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+    def test_login_custom_sub(self):
+        """Test reading user ID from a custom subject claim."""
+        channel = self.jwt_login({"username": "frog"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
     def test_login_no_token(self):
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         ]
     )
 
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.jwt.jwt_enabled = True
-        self.hs.config.jwt.jwt_secret = self.jwt_pubkey
-        self.hs.config.jwt.jwt_algorithm = "RS256"
-        return self.hs
+    def default_config(self):
+        config = super().default_config()
+        config["jwt_config"] = {
+            "enabled": True,
+            "secret": self.jwt_pubkey,
+            "algorithm": "RS256",
+        }
+        return config
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
 # Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
 
@@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_invalid_event(self):
+        """Test that we deny relations on non-existant events"""
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
+        # Unless that event is referenced from another event!
+        self.get_success(
+            self.hs.get_datastore().db_pool.simple_insert(
+                table="event_relations",
+                values={
+                    "event_id": "bar",
+                    "relates_to_id": "foo",
+                    "relation_type": RelationTypes.THREAD,
+                },
+                desc="test_deny_invalid_event",
+            )
+        )
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+    def test_deny_invalid_room(self):
+        """Test that we deny relations on non-existant events"""
+        # Create another room and send a message in it.
+        room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+        parent_id = res["event_id"]
+
+        # Attempt to send an annotation to that event.
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_deny_double_react(self):
         """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_forked_thread(self):
+        """It is invalid to start a thread off a thread."""
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=self.parent_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        parent_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=parent_id,
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_basic_paginate_relations(self):
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEquals(channel.json_body["chunk"], [])
 
+    def test_unknown_relations(self):
+        """Unknown relations should be accepted."""
+        channel = self._send_relation("m.relation.test", "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEquals(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # When bundling the unknown relation is not included.
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+        # But unknown relations can be directly queried.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
     def _send_relation(
         self,
         relation_type: str,
@@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         access_token = self.login(localpart, "abc123")
 
         return user_id, access_token
+
+    def test_background_update(self):
+        """Test the event_arbitrary_relations background update."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_good = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_bad = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_event_id = channel.json_body["event_id"]
+
+        # Clean-up the table as if the inserts did not happen during event creation.
+        self.get_success(
+            self.store.db_pool.simple_delete_many(
+                table="event_relations",
+                column="event_id",
+                iterable=(annotation_event_id_bad, thread_event_id),
+                keyvalues={},
+                desc="RelationsTestCase.test_background_update",
+            )
+        )
+
+        # Only the "good" annotation should be found.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good],
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+        self.wait_for_background_updates()
+
+        # The "good" annotation and the thread should be found, but not the "bad"
+        # annotation.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertCountEqual(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good, thread_event_id],
+        )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 376853fd65..10a4a4dc5e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -25,7 +25,12 @@ from urllib import parse as urlparse
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+)
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
@@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         return event_id
 
 
+class RelationsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["experimental_features"] = {"msc3440_enabled": True}
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("test", "test")
+        self.tok = self.login("test", "test")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+        self.helper.join(
+            room=self.room_id, user=self.second_user_id, tok=self.second_tok
+        )
+
+        self.third_user_id = self.register_user("third", "test")
+        self.third_tok = self.login("third", "test")
+        self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+        # An initial event with a relation from second user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 1"},
+            tok=self.tok,
+        )
+        self.event_id_1 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": self.event_id_1,
+                    "key": "👍",
+                }
+            },
+            tok=self.second_tok,
+        )
+
+        # Another event with a relation from third user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 2"},
+            tok=self.tok,
+        )
+        self.event_id_2 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.REFERENCE,
+                    "event_id": self.event_id_2,
+                }
+            },
+            tok=self.third_tok,
+        )
+
+        # An event with no relations.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "No relations"},
+            tok=self.tok,
+        )
+
+    def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+        """Make a request to /messages with a filter, returns the chunk of events."""
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        return channel.json_body["chunk"]
+
+    def test_filter_relation_senders(self):
+        # Messages which second user reacted to.
+        filter = {"io.element.relation_senders": [self.second_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+        # Messages which third user reacted to.
+        filter = {"io.element.relation_senders": [self.third_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+        # Messages which either user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_type(self):
+        # Messages which have annotations.
+        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+        # Messages which have references.
+        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+        # Messages which have either annotations or references.
+        filter = {
+            "io.element.relation_types": [
+                RelationTypes.ANNOTATION,
+                RelationTypes.REFERENCE,
+            ]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_senders_and_type(self):
+        # Messages which second user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id],
+            "io.element.relation_types": [RelationTypes.ANNOTATION],
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+
 class ContextTestCase(unittest.HomeserverTestCase):
 
     servlets = [
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,10 +19,21 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    Dict,
+    Iterable,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    overload,
+)
 from unittest.mock import patch
 
 import attr
+from typing_extensions import Literal
 
 from twisted.web.resource import Resource
 from twisted.web.server import Site
@@ -45,6 +56,32 @@ class RestHelper:
     site = attr.ib(type=Site)
     auth_user_id = attr.ib()
 
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: Literal[200] = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> str:
+        ...
+
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: int = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> Optional[str]:
+        ...
+
     def create_room_as(
         self,
         room_creator: Optional[str] = None,
@@ -53,10 +90,8 @@ class RestHelper:
         tok: Optional[str] = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
-    ) -> str:
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+    ) -> Optional[str]:
         """
         Create a room.
 
@@ -99,6 +134,8 @@ class RestHelper:
 
         if expect_code == 200:
             return channel.json_body["room_id"]
+        else:
+            return None
 
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(
@@ -168,7 +205,7 @@ class RestHelper:
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
         expect_code: int = 200,
-        expect_errcode: str = None,
+        expect_errcode: Optional[str] = None,
     ) -> None:
         """
         Send a membership state event into a room.
@@ -227,9 +264,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if body is None:
             body = "body_text_here"
@@ -254,9 +289,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -418,7 +451,7 @@ class RestHelper:
             path,
             content=image_data,
             access_token=tok,
-            custom_headers=[(b"Content-Length", str(image_length))],
+            custom_headers=[("Content-Length", str(image_length))],
         )
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
@@ -503,7 +536,7 @@ class RestHelper:
             went.
         """
 
-        cookies = {}
+        cookies: Dict[str, str] = {}
 
         # if we're doing a ui auth, hit the ui auth redirect endpoint
         if ui_auth_session_id:
@@ -625,7 +658,13 @@ class RestHelper:
 
         # hit the redirect url again with the right Host header, which should now issue
         # a cookie and redirect to the SSO provider.
-        location = channel.headers.getRawHeaders("Location")[0]
+        def get_location(channel: FakeChannel) -> str:
+            location_values = channel.headers.getRawHeaders("Location")
+            # Keep mypy happy by asserting that location_values is nonempty
+            assert location_values
+            return location_values[0]
+
+        location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
             self.hs.get_reactor(),
@@ -639,7 +678,7 @@ class RestHelper:
 
         assert channel.code == 302
         channel.extract_cookies(cookies)
-        return channel.headers.getRawHeaders("Location")[0]
+        return get_location(channel)
 
     def initiate_sso_ui_auth(
         self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
diff --git a/tests/server.py b/tests/server.py
index 103351b487..40cf5b12c3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,17 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+    AnyStr,
+    Callable,
+    Dict,
+    Iterable,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import attr
 from typing_extensions import Deque
@@ -217,14 +227,12 @@ def make_request(
     path: Union[bytes, str],
     content: Union[bytes, str, JsonDict] = b"",
     access_token: Optional[str] = None,
-    request: Request = SynapseRequest,
+    request: Type[Request] = SynapseRequest,
     shorthand: bool = True,
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
-    custom_headers: Optional[
-        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-    ] = None,
+    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1ba99ff14..d37736edf8 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,19 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class ProfileStoreTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
-    def test_displayname(self):
+    def test_displayname(self) -> None:
         self.get_success(self.store.create_profile(self.u_frank.localpart))
 
         self.get_success(
@@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
         )
 
-    def test_avatar_url(self):
+    def test_avatar_url(self) -> None:
         self.get_success(self.store.create_profile(self.u_frank.localpart))
 
         self.get_success(
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 0ce0892165..cfc8098af6 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -33,7 +33,7 @@ def fake_listdir(filepath: str) -> List[str]:
         A list of files and folders in the directory.
     """
     if filepath.endswith("full_schemas"):
-        return [SCHEMA_VERSION]
+        return [str(SCHEMA_VERSION)]
 
     return ["99_add_unicorn_to_database.sql"]
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 2873e22ccf..fccab733c0 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -161,6 +161,54 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
 
+    def test__null_byte_in_display_name_properly_handled(self):
+        room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                "room_memberships",
+                {"user_id": "@alice:test"},
+                ["display_name", "event_id"],
+            )
+        )
+        # Check that we only got one result back
+        self.assertEqual(len(res), 1)
+
+        # Check that alice's display name is "alice"
+        self.assertEqual(res[0]["display_name"], "alice")
+
+        # Grab the event_id to use later
+        event_id = res[0]["event_id"]
+
+        # Create a profile with the offending null byte in the display name
+        new_profile = {"displayname": "ali\u0000ce"}
+
+        # Ensure that the change goes smoothly and does not fail due to the null byte
+        self.helper.change_membership(
+            room,
+            self.u_alice,
+            self.u_alice,
+            "join",
+            extra_data=new_profile,
+            tok=self.t_alice,
+        )
+
+        res2 = self.get_success(
+            self.store.db_pool.simple_select_list(
+                "room_memberships",
+                {"user_id": "@alice:test"},
+                ["display_name", "event_id"],
+            )
+        )
+        # Check that we only have two results
+        self.assertEqual(len(res2), 2)
+
+        # Filter out the previous event using the event_id we grabbed above
+        row = [row for row in res2 if row["event_id"] != event_id]
+
+        # Check that alice's display name is now None
+        self.assertEqual(row[0]["display_name"], None)
+
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
new file mode 100644
index 0000000000..ce782c7e1d
--- /dev/null
+++ b/tests/storage/test_stream.py
@@ -0,0 +1,207 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from typing import List
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.types import JsonDict
+
+from tests.unittest import HomeserverTestCase
+
+
+class PaginationTestCase(HomeserverTestCase):
+    """
+    Test the pre-filtering done in the pagination code.
+
+    This is similar to some of the tests in tests.rest.client.test_rooms but here
+    we ensure that the filtering done in the database is applied successfully.
+    """
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["experimental_features"] = {"msc3440_enabled": True}
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("test", "test")
+        self.tok = self.login("test", "test")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+        self.helper.join(
+            room=self.room_id, user=self.second_user_id, tok=self.second_tok
+        )
+
+        self.third_user_id = self.register_user("third", "test")
+        self.third_tok = self.login("third", "test")
+        self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+        # An initial event with a relation from second user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 1"},
+            tok=self.tok,
+        )
+        self.event_id_1 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": self.event_id_1,
+                    "key": "👍",
+                }
+            },
+            tok=self.second_tok,
+        )
+
+        # Another event with a relation from third user.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "Message 2"},
+            tok=self.tok,
+        )
+        self.event_id_2 = res["event_id"]
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.REFERENCE,
+                    "event_id": self.event_id_2,
+                }
+            },
+            tok=self.third_tok,
+        )
+
+        # An event with no relations.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "No relations"},
+            tok=self.tok,
+        )
+
+    def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+        """Make a request to /messages with a filter, returns the chunk of events."""
+
+        from_token = self.get_success(
+            self.hs.get_event_sources().get_current_token_for_pagination()
+        )
+
+        events, next_key = self.get_success(
+            self.hs.get_datastore().paginate_room_events(
+                room_id=self.room_id,
+                from_key=from_token.room_key,
+                to_key=None,
+                direction="b",
+                limit=10,
+                event_filter=Filter(self.hs, filter),
+            )
+        )
+
+        return events
+
+    def test_filter_relation_senders(self):
+        # Messages which second user reacted to.
+        filter = {"io.element.relation_senders": [self.second_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+        # Messages which third user reacted to.
+        filter = {"io.element.relation_senders": [self.third_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+        # Messages which either user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_type(self):
+        # Messages which have annotations.
+        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+        # Messages which have references.
+        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+        # Messages which have either annotations or references.
+        filter = {
+            "io.element.relation_types": [
+                RelationTypes.ANNOTATION,
+                RelationTypes.REFERENCE,
+            ]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 2, chunk)
+        self.assertCountEqual(
+            [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+        )
+
+    def test_filter_relation_senders_and_type(self):
+        # Messages which second user reacted to.
+        filter = {
+            "io.element.relation_senders": [self.second_user_id],
+            "io.element.relation_types": [RelationTypes.ANNOTATION],
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+    def test_duplicate_relation(self):
+        """An event should only be returned once if there are multiple relations to it."""
+        self.helper.send_event(
+            room_id=self.room_id,
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": self.event_id_1,
+                    "key": "A",
+                }
+            },
+            tok=self.second_tok,
+        )
+
+        filter = {"io.element.relation_senders": [self.second_user_id]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(len(chunk), 1, chunk)
+        self.assertEqual(chunk[0].event_id, self.event_id_1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 24fc77d7a7..3eef1c4c05 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             origin,
             event,
             context,
-            state=None,
-            backfilled=False,
         ):
             return context
 
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
 import logging
 import secrets
 import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    AnyStr,
+    Callable,
+    ClassVar,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 from unittest.mock import Mock, patch
 
 from canonicaljson import json
@@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
@@ -45,6 +59,7 @@ from synapse.logging.context import (
     current_context,
     set_current_context,
 )
+from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -81,16 +96,13 @@ def around(target):
     return _around
 
 
-T = TypeVar("T")
-
-
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
     root logger's logging level while that test (case|method) runs."""
 
-    def __init__(self, methodName, *args, **kwargs):
-        super().__init__(methodName, *args, **kwargs)
+    def __init__(self, methodName: str):
+        super().__init__(methodName)
 
         method = getattr(self, methodName)
 
@@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase):
       config dict.
 
     Attributes:
-        servlets (list[function]): List of servlet registration function.
+        servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
-        hijack_auth (bool): Whether to hijack auth to return the user specified
+        hijack_auth: Whether to hijack auth to return the user specified
         in user_id.
     """
 
-    servlets = []
-    hijack_auth = True
-    needs_threadpool = False
+    hijack_auth: ClassVar[bool] = True
+    needs_threadpool: ClassVar[bool] = False
+    servlets: ClassVar[List[RegisterServletsFunc]] = []
 
-    def __init__(self, methodName, *args, **kwargs):
-        super().__init__(methodName, *args, **kwargs)
+    def __init__(self, methodName: str):
+        super().__init__(methodName)
 
         # see if we have any additional config for this test
         method = getattr(self, methodName)
@@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase):
                         None,
                     )
 
-                self.hs.get_auth().get_user_by_req = get_user_by_req
-                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
-                self.hs.get_auth().get_access_token_from_request = Mock(
+                # Type ignore: mypy doesn't like us assigning to methods.
+                self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment]
+                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
+                self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment]
                     return_value="1234"
                 )
 
@@ -318,7 +331,12 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed."""
+        """
+        Block until all background database updates have completed.
+
+        Note that callers must ensure that's a store property created on the
+        testcase.
+        """
         while not self.get_success(
             self.store.db_pool.updates.has_completed_background_updates()
         ):
@@ -403,14 +421,12 @@ class HomeserverTestCase(TestCase):
         path: Union[bytes, str],
         content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
+        request: Type[Request] = SynapseRequest,
         shorthand: bool = True,
-        federation_auth_origin: str = None,
+        federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -425,7 +441,7 @@ class HomeserverTestCase(TestCase):
             a dict.
             shorthand: Whether to try and be helpful and prefix the given URL
             with the usual REST API path, if it doesn't contain it.
-            federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            federation_auth_origin: if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
             content_is_form: Whether the content is URL encoded form data. Adds the
                 'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -584,7 +600,7 @@ class HomeserverTestCase(TestCase):
             nonce_str += b"\x00notadmin"
 
         want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
-        want_mac = want_mac.hexdigest()
+        want_mac_digest = want_mac.hexdigest()
 
         body = json.dumps(
             {
@@ -593,7 +609,7 @@ class HomeserverTestCase(TestCase):
                 "displayname": displayname,
                 "password": password,
                 "admin": admin,
-                "mac": want_mac,
+                "mac": want_mac_digest,
                 "inhibit_login": True,
             }
         )
@@ -639,9 +655,7 @@ class HomeserverTestCase(TestCase):
         username,
         password,
         device_id=None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         """
         Log in a user, and get an access token. Requires the Login API be