summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py108
-rw-r--r--tests/api/test_errors.py43
-rw-r--r--tests/api/test_filtering.py31
-rw-r--r--tests/api/test_ratelimiting.py67
-rw-r--r--tests/app/test_homeserver_start.py4
-rw-r--r--tests/app/test_openid_listener.py9
-rw-r--r--tests/app/test_phone_stats_home.py156
-rw-r--r--tests/appservice/test_api.py82
-rw-r--r--tests/appservice/test_appservice.py113
-rw-r--r--tests/appservice/test_scheduler.py43
-rw-r--r--tests/config/test_oauth_delegation.py278
-rw-r--r--tests/config/test_ratelimiting.py31
-rw-r--r--tests/config/test_workers.py16
-rw-r--r--tests/crypto/test_keyring.py113
-rw-r--r--tests/events/test_presence_router.py5
-rw-r--r--tests/events/test_snapshot.py3
-rw-r--r--tests/events/test_utils.py78
-rw-r--r--tests/federation/test_complexity.py37
-rw-r--r--tests/federation/test_federation_catch_up.py32
-rw-r--r--tests/federation/test_federation_client.py4
-rw-r--r--tests/federation/test_federation_sender.py71
-rw-r--r--tests/federation/test_federation_server.py35
-rw-r--r--tests/federation/transport/test_knocking.py6
-rw-r--r--tests/handlers/test_appservice.py270
-rw-r--r--tests/handlers/test_auth.py27
-rw-r--r--tests/handlers/test_cas.py28
-rw-r--r--tests/handlers/test_device.py197
-rw-r--r--tests/handlers/test_directory.py12
-rw-r--r--tests/handlers/test_e2e_keys.py146
-rw-r--r--tests/handlers/test_federation.py60
-rw-r--r--tests/handlers/test_federation_event.py186
-rw-r--r--tests/handlers/test_message.py15
-rw-r--r--tests/handlers/test_oauth_delegation.py687
-rw-r--r--tests/handlers/test_oidc.py10
-rw-r--r--tests/handlers/test_password_providers.py104
-rw-r--r--tests/handlers/test_presence.py957
-rw-r--r--tests/handlers/test_profile.py47
-rw-r--r--tests/handlers/test_register.py48
-rw-r--r--tests/handlers/test_room_member.py54
-rw-r--r--tests/handlers/test_saml.py17
-rw-r--r--tests/handlers/test_send_email.py69
-rw-r--r--tests/handlers/test_sync.py9
-rw-r--r--tests/handlers/test_typing.py101
-rw-r--r--tests/handlers/test_user_directory.py43
-rw-r--r--tests/handlers/test_worker_lock.py74
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py417
-rw-r--r--tests/http/test_matrixfederationclient.py304
-rw-r--r--tests/http/test_proxy.py53
-rw-r--r--tests/http/test_proxyagent.py4
-rw-r--r--tests/logging/test_opentracing.py43
-rw-r--r--tests/logging/test_remote_handler.py12
-rw-r--r--tests/logging/test_terse_json.py2
-rw-r--r--tests/media/test_base.py41
-rw-r--r--tests/media/test_html_preview.py18
-rw-r--r--tests/media/test_media_storage.py176
-rw-r--r--tests/media/test_oembed.py2
-rw-r--r--tests/media/test_url_previewer.py8
-rw-r--r--tests/metrics/test_metrics.py10
-rw-r--r--tests/module_api/test_api.py21
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py61
-rw-r--r--tests/push/test_email.py55
-rw-r--r--tests/replication/_base.py10
-rw-r--r--tests/replication/storage/_base.py17
-rw-r--r--tests/replication/storage/test_events.py53
-rw-r--r--tests/replication/tcp/streams/test_events.py10
-rw-r--r--tests/replication/tcp/streams/test_to_device.py2
-rw-r--r--tests/replication/test_federation_ack.py1
-rw-r--r--tests/replication/test_federation_sender_shard.py37
-rw-r--r--tests/replication/test_multi_media_repo.py21
-rw-r--r--tests/rest/admin/test_admin.py62
-rw-r--r--tests/rest/admin/test_federation.py6
-rw-r--r--tests/rest/admin/test_jwks.py106
-rw-r--r--tests/rest/admin/test_media.py71
-rw-r--r--tests/rest/admin/test_room.py159
-rw-r--r--tests/rest/admin/test_server_notice.py20
-rw-r--r--tests/rest/admin/test_statistics.py15
-rw-r--r--tests/rest/admin/test_user.py362
-rw-r--r--tests/rest/admin/test_username_available.py2
-rw-r--r--tests/rest/client/test_account.py17
-rw-r--r--tests/rest/client/test_account_data.py5
-rw-r--r--tests/rest/client/test_capabilities.py28
-rw-r--r--tests/rest/client/test_devices.py358
-rw-r--r--tests/rest/client/test_events.py2
-rw-r--r--tests/rest/client/test_filter.py8
-rw-r--r--tests/rest/client/test_login.py164
-rw-r--r--tests/rest/client/test_login_token_request.py71
-rw-r--r--tests/rest/client/test_models.py8
-rw-r--r--tests/rest/client/test_notifications.py5
-rw-r--r--tests/rest/client/test_presence.py6
-rw-r--r--tests/rest/client/test_profile.py12
-rw-r--r--tests/rest/client/test_public_rooms.py4
-rw-r--r--tests/rest/client/test_push_rule_attrs.py67
-rw-r--r--tests/rest/client/test_read_marker.py3
-rw-r--r--tests/rest/client/test_receipts.py221
-rw-r--r--tests/rest/client/test_redactions.py182
-rw-r--r--tests/rest/client/test_register.py14
-rw-r--r--tests/rest/client/test_relations.py87
-rw-r--r--tests/rest/client/test_room_batch.py302
-rw-r--r--tests/rest/client/test_rooms.py124
-rw-r--r--tests/rest/client/test_shadow_banned.py2
-rw-r--r--tests/rest/client/test_sync.py154
-rw-r--r--tests/rest/client/test_third_party_rules.py37
-rw-r--r--tests/rest/client/test_transactions.py9
-rw-r--r--tests/rest/client/utils.py6
-rw-r--r--tests/rest/media/test_url_preview.py126
-rw-r--r--tests/rest/test_well_known.py41
-rw-r--r--tests/server.py95
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py83
-rw-r--r--tests/state/test_v2.py13
-rw-r--r--tests/storage/databases/main/test_events_worker.py57
-rw-r--r--tests/storage/databases/main/test_lock.py340
-rw-r--r--tests/storage/test_appservice.py11
-rw-r--r--tests/storage/test_background_update.py268
-rw-r--r--tests/storage/test_cleanup_extrems.py14
-rw-r--r--tests/storage/test_client_ips.py102
-rw-r--r--tests/storage/test_devices.py18
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_end_to_end_keys.py10
-rw-r--r--tests/storage/test_event_chain.py11
-rw-r--r--tests/storage/test_event_federation.py494
-rw-r--r--tests/storage/test_keys.py137
-rw-r--r--tests/storage/test_main.py2
-rw-r--r--tests/storage/test_monthly_active_users.py23
-rw-r--r--tests/storage/test_profile.py21
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_registration.py46
-rw-r--r--tests/storage/test_rollback_worker.py5
-rw-r--r--tests/storage/test_room.py12
-rw-r--r--tests/storage/test_room_search.py8
-rw-r--r--tests/storage/test_transactions.py20
-rw-r--r--tests/storage/test_txn_limit.py2
-rw-r--r--tests/storage/test_user_filters.py4
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py8
-rw-r--r--tests/test_federation.py70
-rw-r--r--tests/test_server.py35
-rw-r--r--tests/test_state.py19
-rw-r--r--tests/test_terms_auth.py4
-rw-r--r--tests/test_utils/__init__.py35
-rw-r--r--tests/test_utils/logging_setup.py19
-rw-r--r--tests/test_visibility.py10
-rw-r--r--tests/unittest.py62
-rw-r--r--tests/util/caches/test_descriptors.py218
-rw-r--r--tests/util/test_async_helpers.py14
-rw-r--r--tests/util/test_retryutils.py73
-rw-r--r--tests/util/test_task_scheduler.py208
145 files changed, 8165 insertions, 3180 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6e36e73f0d..e00d7215df 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import pymacaroons
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
 from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
@@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -48,7 +47,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # have been called by the HomeserverTestCase machinery.
         hs.datastores.main = self.store  # type: ignore[union-attr]
         hs.get_auth_handler().store = self.store
-        self.auth = Auth(hs)
+        self.auth = InternalAuth(hs)
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
@@ -60,15 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.is_support_user = simple_async_mock(False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.is_support_user = AsyncMock(return_value=False)
 
     def test_get_user_by_req_user_valid_token(self) -> None:
         user_info = TokenLookupResult(
             user_id=self.test_user, token_id=5, device_id="device"
         )
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -77,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -90,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_user_missing_token(self) -> None:
         user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "192.168.10.10"
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "131.111.8.42"
@@ -157,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_appservice_bad_token(self) -> None:
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -171,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_missing_token(self) -> None:
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -188,9 +188,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+
+        class FakeUserInfo:
+            is_guest = False
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -209,7 +212,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -233,10 +236,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a truth-y value
-        self.store.get_device = simple_async_mock({"hidden": False})
+        self.store.get_device = AsyncMock(return_value={"hidden": False})
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -265,10 +268,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a falsey value
-        self.store.get_device = simple_async_mock(None)
+        self.store.get_device = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -282,8 +285,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
 
     def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -291,8 +294,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -302,8 +306,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
         self.auth._track_puppeted_user_ips = True
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -311,8 +315,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -321,7 +326,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.store.insert_client_ip.call_count, 2)
 
     def test_get_user_from_macaroon(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -339,8 +344,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_guest_user_from_macaroon(self) -> None:
-        self.store.get_user_by_id = simple_async_mock({"is_guest": True})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        class FakeUserInfo:
+            is_guest = True
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -370,7 +378,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
 
         e = self.get_failure(
             self.auth_blocking.check_auth_blocking(), ResourceLimitError
@@ -380,25 +388,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
-        self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=small_number_of_users
+        )
         self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self) -> None:
         self.auth_blocking._max_mau_value = 50
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Support users allowed
         self.get_success(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Bots not allowed
         self.get_failure(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
             ResourceLimitError,
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Real users not allowed
         self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
@@ -409,9 +419,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = False
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -426,6 +436,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             access_token_id=None,
             device_id="FOOBAR",
             is_guest=False,
+            scope=set(),
             shadow_banned=False,
             app_service=appservice,
             authenticated_entity="@appservice:server",
@@ -439,9 +450,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -456,6 +467,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             access_token_id=None,
             device_id="FOOBAR",
             is_guest=False,
+            scope=set(),
             shadow_banned=False,
             app_service=appservice,
             authenticated_entity="@appservice:server",
@@ -468,7 +480,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_reserved_threepid(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._max_mau_value = 1
-        self.store.get_monthly_active_count = simple_async_mock(2)
+        self.store.get_monthly_active_count = AsyncMock(return_value=2)
         threepid = {"medium": "email", "address": "reserved@server.com"}
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
new file mode 100644
index 0000000000..8e159029d9
--- /dev/null
+++ b/tests/api/test_errors.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.errors import LimitExceededError
+
+from tests import unittest
+
+
+class LimitExceededErrorTestCase(unittest.TestCase):
+    def test_key_appears_in_context_but_not_error_dict(self) -> None:
+        err = LimitExceededError("needle")
+        serialised = json.dumps(err.error_dict(None))
+        self.assertIn("needle", err.debug_context)
+        self.assertNotIn("needle", serialised)
+
+    # Create a sub-class to avoid mutating the class-level property.
+    class LimitExceededErrorHeaders(LimitExceededError):
+        include_retry_after_header = True
+
+    def test_limit_exceeded_header(self) -> None:
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
+        self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
+        assert err.headers is not None
+        self.assertEqual(err.headers.get("Retry-After"), "1")
+
+    def test_limit_exceeded_rounding(self) -> None:
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
+        self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
+        assert err.headers is not None
+        self.assertEqual(err.headers.get("Retry-After"), "4")
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 222449baac..868f0c6995 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent
 
 user_id = UserID.from_string("@test_user:test")
 user2_id = UserID.from_string("@test_user2:test")
-user_localpart = "test_user"
 
 
 class FilteringTestCase(unittest.HomeserverTestCase):
@@ -48,8 +47,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         invalid_filters: List[JsonDict] = [
             # `account_data` must be a dictionary
             {"account_data": "Hello World"},
-            # `event_fields` entries must not contain backslashes
-            {"event_fields": [r"\\foo"]},
             # `event_format` must be "client" or "federation"
             {"event_format": "other"},
             # `not_rooms` must contain valid room IDs
@@ -114,10 +111,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
                 "event_format": "client",
                 "event_fields": ["type", "content", "sender"],
             },
-            # a single backslash should be permitted (though it is debatable whether
-            # it should be permitted before anything other than `.`, and what that
-            # actually means)
-            #
             # (note that event_fields is implemented in
             # synapse.events.utils.serialize_event, and so whether this actually works
             # is tested elsewhere. We just want to check that it is allowed through the
@@ -455,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         ]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_presence(presence_states))
@@ -485,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         ]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart + "2", filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_presence(presence_states))
@@ -504,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         events = [event]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_room_state(events=events))
@@ -525,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         events = [event]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_room_state(events))
@@ -609,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             user_filter_json,
             (
                 self.get_success(
-                    self.datastore.get_user_filter(
-                        user_localpart=user_localpart, filter_id=0
-                    )
+                    self.datastore.get_user_filter(user_id=user_id, filter_id=0)
                 )
             ),
         )
@@ -626,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         self.assertEqual(filter.get_filter_json(), user_filter_json)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fa6c1c02ce..a24638c9ef 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,5 +1,6 @@
 from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
 from synapse.appservice import ApplicationService
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.types import create_requester
 
 from tests import unittest
@@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # Shouldn't raise
@@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             )
         )
 
-        limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=store,
+            clock=self.clock,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
+        )
 
         # Shouldn't raise
         for _ in range(20):
@@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
         # Test that 4 actions aren't allowed with a maximum burst of 3.
         allowed, time_allowed = self.get_success_or_raise(
@@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
         )
 
         def consume_at(time: float) -> bool:
@@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe two actions, leaving room in the bucket for one more.
@@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe three actions, filling up the bucket.
@@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe four actions, exceeding the bucket.
diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
index 788c935537..0201933b04 100644
--- a/tests/app/test_homeserver_start.py
+++ b/tests/app/test_homeserver_start.py
@@ -25,7 +25,9 @@ class HomeserverAppStartTestCase(ConfigFileTestCase):
         # Add a blank line as otherwise the next addition ends up on a line with a comment
         self.add_lines_to_config(["  "])
         self.add_lines_to_config(["worker_app: test_worker_app"])
-
+        self.add_lines_to_config(["worker_log_config: /data/logconfig.config"])
+        self.add_lines_to_config(["instance_map:"])
+        self.add_lines_to_config(["  main:", "    host: 127.0.0.1", "    port: 1234"])
         # Ensure that starting master process with worker config raises an exception
         with self.assertRaises(ConfigError):
             synapse.app.homeserver.setup(["-c", self.config_file])
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 6e0413400e..21c5309740 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -31,9 +31,7 @@ from tests.unittest import HomeserverTestCase
 
 class FederationReaderOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver(
-            federation_http_client=None, homeserver_to_use=GenericWorkerServer
-        )
+        hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
         return hs
 
     def default_config(self) -> JsonDict:
@@ -42,6 +40,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         # have to tell the FederationHandler not to try to access stuff that is only
         # in the primary store.
         conf["worker_app"] = "yes"
+        conf["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}}
 
         return conf
 
@@ -90,9 +89,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
 @patch("synapse.app.homeserver.KeyResource", new=Mock())
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver(
-            federation_http_client=None, homeserver_to_use=SynapseHomeServer
-        )
+        hs = self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
         return hs
 
     @parameterized.expand(
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index a860eedbcf..93af614def 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -4,7 +4,6 @@ from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests import unittest
 from tests.server import ThreadedMemoryReactorClock
 from tests.unittest import HomeserverTestCase
 
@@ -12,154 +11,6 @@ FIVE_MINUTES_IN_SECONDS = 300
 ONE_DAY_IN_SECONDS = 86400
 
 
-class PhoneHomeTestCase(HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        room.register_servlets,
-        login.register_servlets,
-    ]
-
-    # Override the retention time for the user_ips table because otherwise it
-    # gets pruned too aggressively for our R30 test.
-    @unittest.override_config({"user_ips_max_age": "365d"})
-    def test_r30_minimum_usage(self) -> None:
-        """
-        Tests the minimum amount of interaction necessary for the R30 metric
-        to consider a user 'retained'.
-        """
-
-        # Register a user, log it in, create a room and send a message
-        user_id = self.register_user("u1", "secret!")
-        access_token = self.login("u1", "secret!")
-        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
-        self.helper.send(room_id, "message", tok=access_token)
-
-        # Check the R30 results do not count that user.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-        # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
-        # bang on 30 days later).
-        self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
-
-        # (Make sure the user isn't somehow counted by this point.)
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-        # Send a message (this counts as activity)
-        self.helper.send(room_id, "message2", tok=access_token)
-
-        # We have to wait some time for _update_client_ips_batch to get
-        # called and update the user_ips table.
-        self.reactor.advance(2 * 60 * 60)
-
-        # *Now* the user is counted.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-        # Advance 29 days. The user has now not posted for 29 days.
-        self.reactor.advance(29 * ONE_DAY_IN_SECONDS)
-
-        # The user is still counted.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-        # Advance another day. The user has now not posted for 30 days.
-        self.reactor.advance(ONE_DAY_IN_SECONDS)
-
-        # The user is now no longer counted in R30.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-    def test_r30_minimum_usage_using_default_config(self) -> None:
-        """
-        Tests the minimum amount of interaction necessary for the R30 metric
-        to consider a user 'retained'.
-
-        N.B. This test does not override the `user_ips_max_age` config setting,
-        which defaults to 28 days.
-        """
-
-        # Register a user, log it in, create a room and send a message
-        user_id = self.register_user("u1", "secret!")
-        access_token = self.login("u1", "secret!")
-        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
-        self.helper.send(room_id, "message", tok=access_token)
-
-        # Check the R30 results do not count that user.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-        # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
-        # bang on 30 days later).
-        self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
-
-        # (Make sure the user isn't somehow counted by this point.)
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-        # Send a message (this counts as activity)
-        self.helper.send(room_id, "message2", tok=access_token)
-
-        # We have to wait some time for _update_client_ips_batch to get
-        # called and update the user_ips table.
-        self.reactor.advance(2 * 60 * 60)
-
-        # *Now* the user is counted.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-        # Advance 27 days. The user has now not posted for 27 days.
-        self.reactor.advance(27 * ONE_DAY_IN_SECONDS)
-
-        # The user is still counted.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-        # Advance another day. The user has now not posted for 28 days.
-        self.reactor.advance(ONE_DAY_IN_SECONDS)
-
-        # The user is now no longer counted in R30.
-        # (This is because the user_ips table has been pruned, which by default
-        # only preserves the last 28 days of entries.)
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-    def test_r30_user_must_be_retained_for_at_least_a_month(self) -> None:
-        """
-        Tests that a newly-registered user must be retained for a whole month
-        before appearing in the R30 statistic, even if they post every day
-        during that time!
-        """
-        # Register a user and send a message
-        user_id = self.register_user("u1", "secret!")
-        access_token = self.login("u1", "secret!")
-        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
-        self.helper.send(room_id, "message", tok=access_token)
-
-        # Check the user does not contribute to R30 yet.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 0})
-
-        for _ in range(30):
-            # This loop posts a message every day for 30 days
-            self.reactor.advance(ONE_DAY_IN_SECONDS)
-            self.helper.send(room_id, "I'm still here", tok=access_token)
-
-            # Notice that the user *still* does not contribute to R30!
-            r30_results = self.get_success(
-                self.hs.get_datastores().main.count_r30_users()
-            )
-            self.assertEqual(r30_results, {"all": 0})
-
-        self.reactor.advance(ONE_DAY_IN_SECONDS)
-        self.helper.send(room_id, "Still here!", tok=access_token)
-
-        # *Now* the user appears in R30.
-        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-
 class PhoneHomeR30V2TestCase(HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -175,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
     def make_homeserver(
         self, reactor: ThreadedMemoryReactorClock, clock: Clock
     ) -> HomeServer:
-        hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+        hs = super().make_homeserver(reactor, clock)
 
         # We don't want our tests to actually report statistics, so check
         # that it's not enabled
@@ -363,11 +214,6 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
             r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
         )
 
-        # Check that this is a situation where old R30 differs:
-        # old R30 DOES count this as 'retained'.
-        r30_results = self.get_success(store.count_r30_users())
-        self.assertEqual(r30_results, {"all": 1, "ios": 1})
-
         # Now we want to check that the user will still be able to appear in
         # R30v2 as long as the user performs some other activity between
         # 30 and 60 days later.
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 15fce165b6..366b6fd5f0 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -11,18 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Mapping, Sequence, Union
+from typing import Any, List, Mapping, Optional, Sequence, Union
 from unittest.mock import Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.errors import HttpResponseException
 from synapse.appservice import ApplicationService
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
+from tests.unittest import override_config
 
 PROTOCOL = "myproto"
 TOKEN = "myastoken"
@@ -40,7 +40,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             hs_token=TOKEN,
         )
 
-    def test_query_3pe_authenticates_token(self) -> None:
+    def test_query_3pe_authenticates_token_via_header(self) -> None:
         """
         Tests that 3pe queries to the appservice are authenticated
         with the appservice's token.
@@ -75,12 +75,18 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             args: Mapping[Any, Any],
             headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
         ) -> List[JsonDict]:
-            # Ensure the access token is passed as both a header and query arg.
-            if not headers.get("Authorization") or not args.get(b"access_token"):
+            # Ensure the access token is passed as a header.
+            if not headers or not headers.get(b"Authorization"):
                 raise RuntimeError("Access token not provided")
+            # ... and not as a query param
+            if b"access_token" in args:
+                raise RuntimeError(
+                    "Access token should not be passed as a query param."
+                )
 
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
-            self.assertEqual(args.get(b"access_token"), TOKEN)
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             self.request_url = url
             if url == URL_USER:
                 return SUCCESS_RESULT_USER
@@ -92,7 +98,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
                 )
 
         # We assign to a method, which mypy doesn't like.
-        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[assignment]
+        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[method-assign]
 
         result = self.get_success(
             self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
@@ -107,10 +113,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.request_url, URL_LOCATION)
         self.assertEqual(result, SUCCESS_RESULT_LOCATION)
 
-    def test_fallback(self) -> None:
+    @override_config({"use_appservice_legacy_authorization": True})
+    def test_query_3pe_authenticates_token_via_param(self) -> None:
         """
-        Tests that the fallback to legacy URLs works.
+        Tests that 3pe queries to the appservice are authenticated
+        with the appservice's token.
         """
+
         SUCCESS_RESULT_USER = [
             {
                 "protocol": PROTOCOL,
@@ -120,44 +129,63 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
                 },
             }
         ]
+        SUCCESS_RESULT_LOCATION = [
+            {
+                "protocol": PROTOCOL,
+                "alias": "#a:room",
+                "fields": {
+                    "more": "fields",
+                },
+            }
+        ]
 
         URL_USER = f"{URL}/_matrix/app/v1/thirdparty/user/{PROTOCOL}"
-        FALLBACK_URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}"
+        URL_LOCATION = f"{URL}/_matrix/app/v1/thirdparty/location/{PROTOCOL}"
 
         self.request_url = None
-        self.v1_seen = False
 
         async def get_json(
             url: str,
             args: Mapping[Any, Any],
-            headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+            headers: Optional[
+                Mapping[Union[str, bytes], Sequence[Union[str, bytes]]]
+            ] = None,
         ) -> List[JsonDict]:
-            # Ensure the access token is passed as both a header and query arg.
-            if not headers.get("Authorization") or not args.get(b"access_token"):
-                raise RuntimeError("Access token not provided")
+            # Ensure the access token is passed as a both a query param and in the headers.
+            if not args.get(b"access_token"):
+                raise RuntimeError("Access token should be provided in query params.")
+            if not headers or not headers.get(b"Authorization"):
+                raise RuntimeError("Access token should be provided in auth headers.")
 
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
             self.assertEqual(args.get(b"access_token"), TOKEN)
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             self.request_url = url
             if url == URL_USER:
-                self.v1_seen = True
-                raise HttpResponseException(404, "NOT_FOUND", b"NOT_FOUND")
-            elif url == FALLBACK_URL_USER:
                 return SUCCESS_RESULT_USER
+            elif url == URL_LOCATION:
+                return SUCCESS_RESULT_LOCATION
             else:
                 raise RuntimeError(
                     "URL provided was invalid. This should never be seen."
                 )
 
         # We assign to a method, which mypy doesn't like.
-        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[assignment]
+        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[method-assign]
 
         result = self.get_success(
             self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
         )
-        self.assertTrue(self.v1_seen)
-        self.assertEqual(self.request_url, FALLBACK_URL_USER)
+        self.assertEqual(self.request_url, URL_USER)
         self.assertEqual(result, SUCCESS_RESULT_USER)
+        result = self.get_success(
+            self.api.query_3pe(
+                self.service, "location", PROTOCOL, {b"some": [b"field"]}
+            )
+        )
+        self.assertEqual(self.request_url, URL_LOCATION)
+        self.assertEqual(result, SUCCESS_RESULT_LOCATION)
 
     def test_claim_keys(self) -> None:
         """
@@ -184,14 +212,16 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
         ) -> JsonDict:
             # Ensure the access token is passed as both a header and query arg.
-            if not headers.get("Authorization"):
+            if not headers.get(b"Authorization"):
                 raise RuntimeError("Access token not provided")
 
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             return RESPONSE
 
         # We assign to a method, which mypy doesn't like.
-        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[assignment]
+        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[method-assign]
 
         MISSING_KEYS = [
             # Known user, known device, missing algorithm.
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index dee976356f..6ac5fc1ae7 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-from typing import Generator
-from unittest.mock import Mock
+from typing import Any, Generator
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 
 def _regex(regex: str, exclusive: bool = True) -> Namespace:
@@ -43,21 +42,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
 
         self.store = Mock()
-        self.store.get_aliases_for_room = simple_async_mock([])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
 
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -65,15 +62,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
         self.assertFalse(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -81,17 +76,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_member_is_checked(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
         self.event.type = "m.room.member"
         self.event.state_key = "@irc_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -99,17 +92,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_id_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ROOMS].append(
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -117,38 +108,32 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_room_id_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ROOMS].append(
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
         self.assertFalse(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_regex_alias_match(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
@@ -192,14 +177,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_alias_no_match(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
@@ -213,28 +198,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_regex_multiple_matches(
         self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_barfoo:matrix.org"]
+        )
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_interested_in_self(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
         # make sure invites get through
         self.service.sender = "@appservice:name"
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@@ -243,32 +226,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.state_key = self.service.sender
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
 
     @defer.inlineCallbacks
-    def test_member_list_match(
-        self,
-    ) -> Generator["defer.Deferred[object]", object, None]:
+    def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         # Note that @irc_fo:here is the AS user.
-        self.store.get_local_users_in_room = simple_async_mock(
-            ["@alice:here", "@irc_fo:here", "@bob:here"]
+        self.store.get_local_users_in_room = AsyncMock(
+            return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
         )
-        self.store.get_aliases_for_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
 
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(
             (
-                yield defer.ensureDeferred(
-                    self.service.is_interested_in_event(
-                        self.event.event_id, self.event, self.store
-                    )
+                yield self.service.is_interested_in_event(
+                    self.event.event_id, self.event, self.store
                 )
             )
         )
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index e2a3bad065..445919417e 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List, Optional, Sequence, Tuple, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from typing_extensions import TypeAlias
 
@@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 from ..utils import MockClock
 
@@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(True)
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=True)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         events = [Mock(), Mock()]
 
         txn = Mock(id="idhere", service=service, events=events)
-        self.store.get_appservice_state = simple_async_mock(
-            ApplicationServiceState.DOWN
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.DOWN
         )
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        self.store.set_appservice_state = simple_async_mock(True)
-        txn.send = simple_async_mock(False)  # fails to send
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        self.store.set_appservice_state = AsyncMock(return_value=True)
+        txn.send = AsyncMock(return_value=False)  # fails to send
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.as_api = Mock()
         self.store = Mock()
         self.service = Mock()
-        self.callback = simple_async_mock()
+        self.callback = AsyncMock()
         self.recoverer = _Recoverer(
             clock=cast(Clock, self.clock),
             as_api=self.as_api,
@@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=None)
         # wait for exp backoff
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
@@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(False)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=False)
+        txn.complete = AsyncMock(return_value=None)
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
@@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEqual(3, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
         self.assertEqual(0, self.callback.call_count)
-        txn.send = simple_async_mock(True)  # successfully send the txn
+        txn.send = AsyncMock(return_value=True)  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
         self.assertEqual(1, txn.send.call_count)  # new mock reset call count
@@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
         self.scheduler = ApplicationServiceScheduler(hs)
         self.txn_ctrl = Mock()
-        self.txn_ctrl.send = simple_async_mock()
+        self.txn_ctrl.send = AsyncMock()
 
         # Replace instantiated _TransactionController instances with our Mock
         self.scheduler.txn_ctrl = self.txn_ctrl
diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
new file mode 100644
index 0000000000..5c91031746
--- /dev/null
+++ b/tests/config/test_oauth_delegation.py
@@ -0,0 +1,278 @@
+# Copyright 2023 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from unittest.mock import Mock
+
+from synapse.config import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.module_api import ModuleApi
+from synapse.types import JsonDict
+
+from tests.server import get_clock, setup_test_homeserver
+from tests.unittest import TestCase, skip_unless
+from tests.utils import default_config
+
+try:
+    import authlib  # noqa: F401
+
+    HAS_AUTHLIB = True
+except ImportError:
+    HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+
+
+class CustomAuthModule:
+    """A module which registers a password auth provider."""
+
+    @staticmethod
+    def parse_config(config: JsonDict) -> None:
+        pass
+
+    def __init__(self, config: None, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={("m.login.password", ("password",)): Mock()},
+        )
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(TestCase):
+    """Test that the Homeserver fails to initialize if the config is invalid."""
+
+    def setUp(self) -> None:
+        self.config_dict: JsonDict = {
+            **default_config("test"),
+            "public_baseurl": BASE_URL,
+            "enable_registration": False,
+            "experimental_features": {
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": ISSUER,
+                    "client_id": CLIENT_ID,
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": CLIENT_SECRET,
+                }
+            },
+        }
+
+    def parse_config(self) -> HomeServerConfig:
+        config = HomeServerConfig()
+        config.parse_config_dict(self.config_dict, "", "")
+        return config
+
+    def test_client_secret_post_works(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_post",
+            client_secret=CLIENT_SECRET,
+        )
+
+        self.parse_config()
+
+    def test_client_secret_post_requires_client_secret(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_post",
+            client_secret=None,
+        )
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_client_secret_basic_works(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_basic",
+            client_secret=CLIENT_SECRET,
+        )
+
+        self.parse_config()
+
+    def test_client_secret_basic_requires_client_secret(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_basic",
+            client_secret=None,
+        )
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_client_secret_jwt_works(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_jwt",
+            client_secret=CLIENT_SECRET,
+        )
+
+        self.parse_config()
+
+    def test_client_secret_jwt_requires_client_secret(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="client_secret_jwt",
+            client_secret=None,
+        )
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_invalid_client_auth_method(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="invalid",
+        )
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_private_key_jwt_requires_jwk(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="private_key_jwt",
+        )
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_private_key_jwt_works(self) -> None:
+        self.config_dict["experimental_features"]["msc3861"].update(
+            client_auth_method="private_key_jwt",
+            jwk={
+                "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+                "kty": "RSA",
+                "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+                "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+                "e": "AQAB",
+                "kid": "test",
+                "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+                "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+                "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+                "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+            },
+        )
+        self.parse_config()
+
+    def test_registration_cannot_be_enabled(self) -> None:
+        self.config_dict["enable_registration"] = True
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_user_consent_cannot_be_enabled(self) -> None:
+        tmpdir = self.mktemp()
+        os.mkdir(tmpdir)
+        self.config_dict["user_consent"] = {
+            "require_at_registration": True,
+            "version": "1",
+            "template_dir": tmpdir,
+            "server_notice_content": {
+                "msgtype": "m.text",
+                "body": "foo",
+            },
+        }
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_password_config_cannot_be_enabled(self) -> None:
+        self.config_dict["password_config"] = {"enabled": True}
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_oidc_sso_cannot_be_enabled(self) -> None:
+        self.config_dict["oidc_providers"] = [
+            {
+                "idp_id": "microsoft",
+                "idp_name": "Microsoft",
+                "issuer": "https://login.microsoftonline.com/<tenant id>/v2.0",
+                "client_id": "<client id>",
+                "client_secret": "<client secret>",
+                "scopes": ["openid", "profile"],
+                "authorization_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize",
+                "token_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token",
+                "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
+            }
+        ]
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_cas_sso_cannot_be_enabled(self) -> None:
+        self.config_dict["cas_config"] = {
+            "enabled": True,
+            "server_url": "https://cas-server.com",
+            "displayname_attribute": "name",
+            "required_attributes": {"userGroup": "staff", "department": "None"},
+        }
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_auth_providers_cannot_be_enabled(self) -> None:
+        self.config_dict["modules"] = [
+            {
+                "module": f"{__name__}.{CustomAuthModule.__qualname__}",
+                "config": {},
+            }
+        ]
+
+        # This requires actually setting up an HS, as the module will be run on setup,
+        # which should raise as the module tries to register an auth provider
+        config = self.parse_config()
+        reactor, clock = get_clock()
+        with self.assertRaises(ConfigError):
+            setup_test_homeserver(
+                self.addCleanup, reactor=reactor, clock=clock, config=config
+            )
+
+    def test_jwt_auth_cannot_be_enabled(self) -> None:
+        self.config_dict["jwt_config"] = {
+            "enabled": True,
+            "secret": "my-secret-token",
+            "algorithm": "HS256",
+        }
+
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_login_via_existing_session_cannot_be_enabled(self) -> None:
+        self.config_dict["login_via_existing_session"] = {"enabled": True}
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_captcha_cannot_be_enabled(self) -> None:
+        self.config_dict.update(
+            enable_registration_captcha=True,
+            recaptcha_public_key="test",
+            recaptcha_private_key="test",
+        )
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_refreshable_tokens_cannot_be_enabled(self) -> None:
+        self.config_dict.update(
+            refresh_token_lifetime="24h",
+            refreshable_access_token_lifetime="10m",
+            nonrefreshable_access_token_lifetime="24h",
+        )
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_session_lifetime_cannot_be_set(self) -> None:
+        self.config_dict["session_lifetime"] = "24h"
+        with self.assertRaises(ConfigError):
+            self.parse_config()
+
+    def test_enable_3pid_changes_cannot_be_enabled(self) -> None:
+        self.config_dict["enable_3pid_changes"] = True
+        with self.assertRaises(ConfigError):
+            self.parse_config()
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index f12147eaa0..0c27dd21e2 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -12,11 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import RatelimitSettings
 
 from tests.unittest import TestCase
 from tests.utils import default_config
 
 
+class ParseRatelimitSettingsTestcase(TestCase):
+    def test_depth_1(self) -> None:
+        cfg = {
+            "a": {
+                "per_second": 5,
+                "burst_count": 10,
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a")
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+    def test_depth_2(self) -> None:
+        cfg = {
+            "a": {
+                "b": {
+                    "per_second": 5,
+                    "burst_count": 10,
+                },
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a.b")
+        self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
+
+    def test_missing(self) -> None:
+        parsed = RatelimitSettings.parse(
+            {}, "a", defaults={"per_second": 5, "burst_count": 10}
+        )
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+
 class RatelimitConfigTestCase(TestCase):
     def test_parse_rc_federation(self) -> None:
         config_dict = default_config("test")
diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py
index 49a6bdf408..2a643ae4f3 100644
--- a/tests/config/test_workers.py
+++ b/tests/config/test_workers.py
@@ -94,6 +94,7 @@ class WorkerDutyConfigTestCase(TestCase):
                 # so that it doesn't raise an exception here.
                 # (This is not read by `_should_this_worker_perform_duty`.)
                 "notify_appservices": False,
+                "instance_map": {"main": {"host": "127.0.0.1", "port": 0}},
             },
         )
 
@@ -138,7 +139,9 @@ class WorkerDutyConfigTestCase(TestCase):
         """
 
         main_process_config = self._make_worker_config(
-            worker_app="synapse.app.homeserver", worker_name=None
+            worker_app="synapse.app.homeserver",
+            worker_name=None,
+            extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}},
         )
 
         self.assertTrue(
@@ -203,6 +206,7 @@ class WorkerDutyConfigTestCase(TestCase):
                 # so that it doesn't raise an exception here.
                 # (This is not read by `_should_this_worker_perform_duty`.)
                 "notify_appservices": False,
+                "instance_map": {"main": {"host": "127.0.0.1", "port": 0}},
             },
         )
 
@@ -236,7 +240,9 @@ class WorkerDutyConfigTestCase(TestCase):
         Tests new config options. This is for the master's config.
         """
         main_process_config = self._make_worker_config(
-            worker_app="synapse.app.homeserver", worker_name=None
+            worker_app="synapse.app.homeserver",
+            worker_name=None,
+            extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}},
         )
 
         self.assertTrue(
@@ -262,7 +268,9 @@ class WorkerDutyConfigTestCase(TestCase):
         Tests new config options. This is for the worker's config.
         """
         appservice_worker_config = self._make_worker_config(
-            worker_app="synapse.app.generic_worker", worker_name="worker1"
+            worker_app="synapse.app.generic_worker",
+            worker_name="worker1",
+            extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}},
         )
 
         self.assertTrue(
@@ -298,6 +306,7 @@ class WorkerDutyConfigTestCase(TestCase):
             extras={
                 "notify_appservices_from_worker": "worker2",
                 "update_user_directory_from_worker": "worker1",
+                "instance_map": {"main": {"host": "127.0.0.1", "port": 0}},
             },
         )
         self.assertFalse(worker1_config.should_notify_appservices)
@@ -309,6 +318,7 @@ class WorkerDutyConfigTestCase(TestCase):
             extras={
                 "notify_appservices_from_worker": "worker2",
                 "update_user_directory_from_worker": "worker1",
+                "instance_map": {"main": {"host": "127.0.0.1", "port": 0}},
             },
         )
         self.assertTrue(worker2_config.should_notify_appservices)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 7c63b2ea4c..c5700771b0 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -45,7 +45,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import logcontext_clean, override_config
 
 
@@ -190,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_keys_json(
+        r = self.hs.get_datastores().main.store_server_keys_response(
             "server9",
-            get_key_id(key1),
             from_server="test",
-            ts_now_ms=int(time.time() * 1000),
-            ts_expires_ms=1000,
+            ts_added_ms=int(time.time() * 1000),
+            verify_keys={
+                get_key_id(key1): FetchKeyResult(
+                    verify_key=get_verify_key(key1), valid_until_ts=1000
+                )
+            },
             # The entire response gets signed & stored, just include the bits we
             # care about.
-            key_json_bytes=canonicaljson.encode_canonical_json(
-                {
-                    "verify_keys": {
-                        get_key_id(key1): {
-                            "key": encode_verify_key_base64(get_verify_key(key1))
-                        }
+            response_json={
+                "verify_keys": {
+                    get_key_id(key1): {
+                        "key": encode_verify_key_base64(get_verify_key(key1))
                     }
                 }
-            ),
+            },
         )
         self.get_success(r)
 
@@ -286,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
         self.get_success(d)
 
-    def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
-        """Tests that we correctly handle key requests for keys we've stored
-        with a null `ts_valid_until_ms`
-        """
-        mock_fetcher = Mock()
-        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
-
-        key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_signature_keys(
-            "server9",
-            int(time.time() * 1000),
-            # None is not a valid value in FetchKeyResult, but we're abusing this
-            # API to insert null values into the database. The nulls get converted
-            # to 0 when fetched in KeyStore.get_server_signature_keys.
-            {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)},  # type: ignore[arg-type]
-        )
-        self.get_success(r)
-
-        json1: JsonDict = {}
-        signedjson.sign.sign_json(json1, "server9", key1)
-
-        # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = self.hs.get_datastores().main.get_server_signature_keys(
-            [("server9", get_key_id(key1))]
-        )
-        result = self.get_success(d)
-        self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
-
     def test_verify_json_dedupes_key_requests(self) -> None:
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key("1")
@@ -456,24 +428,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], SERVER_NAME)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
         # we expect it to be encoded as canonical json *before* it hits the db
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
@@ -576,23 +543,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_get_multiple_keys_from_perspectives(self) -> None:
         """Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +661,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_invalid_perspectives_responses(self) -> None:
         """Check that invalid responses from the perspectives server are rejected"""
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6fb1f1bd6e..0fcfe38efa 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
 from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
-from tests.test_utils import simple_async_mock
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         hs = self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.assertEqual(
             context.state_group_before_event, d_context.state_group_before_event
         )
-        self.assertEqual(context.prev_group, d_context.prev_group)
-        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
         self.assertEqual(context.app_service, d_context.app_service)
 
         self.assertEqual(
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index e40eac2eb0..978612e432 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -16,6 +16,7 @@ import unittest as stdlib_unittest
 from typing import Any, List, Mapping, Optional
 
 import attr
+from parameterized import parameterized
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
@@ -23,6 +24,7 @@ from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import (
     PowerLevelsContent,
     SerializeEventConfig,
+    _split_field,
     copy_and_fixup_power_levels_contents,
     maybe_upsert_event_field,
     prune_event,
@@ -138,18 +140,16 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-        # As of MSC2176 we now redact the membership and prev_states keys.
+        # As of room versions we now redact the membership, prev_states, and origin keys.
         self.run_test(
-            {"type": "A", "prev_state": "prev_state", "membership": "join"},
-            {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
-            room_version=RoomVersions.MSC2176,
-        )
-
-        # As of MSC3989 we now redact the origin key.
-        self.run_test(
-            {"type": "A", "origin": "example.com"},
+            {
+                "type": "A",
+                "prev_state": "prev_state",
+                "membership": "join",
+                "origin": "example.com",
+            },
             {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
-            room_version=RoomVersions.MSC3989,
+            room_version=RoomVersions.V11,
         )
 
     def test_unsigned(self) -> None:
@@ -225,16 +225,21 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-        # After MSC2176, create events get nothing redacted.
+        # After MSC2176, create events should preserve field `content`
         self.run_test(
-            {"type": "m.room.create", "content": {"not_a_real_key": True}},
+            {
+                "type": "m.room.create",
+                "content": {"not_a_real_key": True},
+                "origin": "some_homeserver",
+                "nonsense_field": "some_random_garbage",
+            },
             {
                 "type": "m.room.create",
                 "content": {"not_a_real_key": True},
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC2176,
+            room_version=RoomVersions.V11,
         )
 
     def test_power_levels(self) -> None:
@@ -284,7 +289,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC2176,
+            room_version=RoomVersions.V11,
         )
 
     def test_alias_event(self) -> None:
@@ -347,7 +352,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC2176,
+            room_version=RoomVersions.V11,
         )
 
     def test_join_rules(self) -> None:
@@ -470,7 +475,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC3821,
+            room_version=RoomVersions.V11,
         )
 
         # Ensure this doesn't break if an invalid field is sent.
@@ -489,7 +494,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC3821,
+            room_version=RoomVersions.V11,
         )
 
         self.run_test(
@@ -507,7 +512,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 "signatures": {},
                 "unsigned": {},
             },
-            room_version=RoomVersions.MSC3821,
+            room_version=RoomVersions.V11,
         )
 
     def test_relations(self) -> None:
@@ -794,3 +799,40 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
     def test_invalid_nesting_raises_type_error(self) -> None:
         with self.assertRaises(TypeError):
             copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})  # type: ignore[dict-item]
+
+
+class SplitFieldTestCase(stdlib_unittest.TestCase):
+    @parameterized.expand(
+        [
+            # A field with no dots.
+            ["m", ["m"]],
+            # Simple dotted fields.
+            ["m.foo", ["m", "foo"]],
+            ["m.foo.bar", ["m", "foo", "bar"]],
+            # Backslash is used as an escape character.
+            [r"m\.foo", ["m.foo"]],
+            [r"m\\.foo", ["m\\", "foo"]],
+            [r"m\\\.foo", [r"m\.foo"]],
+            [r"m\\\\.foo", ["m\\\\", "foo"]],
+            [r"m\foo", [r"m\foo"]],
+            [r"m\\foo", [r"m\foo"]],
+            [r"m\\\foo", [r"m\\foo"]],
+            [r"m\\\\foo", [r"m\\foo"]],
+            # Ensure that escapes at the end don't cause issues.
+            ["m.foo\\", ["m", "foo\\"]],
+            ["m.foo\\", ["m", "foo\\"]],
+            [r"m.foo\.", ["m", "foo."]],
+            [r"m.foo\\.", ["m", "foo\\", ""]],
+            [r"m.foo\\\.", ["m", r"foo\."]],
+            # Empty parts (corresponding to properties which are an empty string) are allowed.
+            [".m", ["", "m"]],
+            ["..m", ["", "", "m"]],
+            ["m.", ["m", ""]],
+            ["m..", ["m", "", ""]],
+            ["m..foo", ["m", "", "foo"]],
+            # Invalid escape sequences.
+            [r"\m", [r"\m"]],
+        ]
+    )
+    def test_split_field(self, input: str, expected: str) -> None:
+        self.assertEqual(_split_field(input), expected)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 129d7cfd93..73a2766baf 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
@@ -20,7 +20,6 @@ from synapse.rest.client import login, room
 from synapse.types import JsonDict, UserID, create_requester
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -58,7 +57,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         async def get_current_state_event_counts(room_id: str) -> int:
             return int(500 * 1.23)
 
-        store.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
+        store.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[method-assign]
 
         # Get the room complexity again -- make sure it's our artificial value
         channel = self.make_signed_federation_request(
@@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -143,16 +142,16 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value=None)  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         # Artificially raise the complexity
         async def get_current_state_event_counts(room_id: str) -> int:
             return 600
 
-        self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[method-assign]
 
         d = handler._remote_join(
             create_requester(u1),
@@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 391ae51707..75ae740b43 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,6 @@
 from typing import Callable, Collection, List, Optional, Tuple
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -19,7 +19,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # This mock is crucial for destination_rooms to be populated.
         # TODO: this seems to no longer be the case---tests pass with this mock
         # commented out.
-        state_storage_controller.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        state_storage_controller.get_current_hosts_in_room = AsyncMock(  # type: ignore[method-assign]
+            return_value={"test", "host2"}
         )
 
         # whenever send_transaction is called, record the pdu data
@@ -431,28 +431,24 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # ACT: call _wake_destinations_needing_catchup
 
         # patch wake_destination to just count the destinations instead
-        woken = []
+        woken = set()
 
         def wake_destination_track(destination: str) -> None:
-            woken.append(destination)
+            woken.add(destination)
 
-        self.federation_sender.wake_destination = wake_destination_track  # type: ignore[assignment]
+        self.federation_sender.wake_destination = wake_destination_track  # type: ignore[method-assign]
 
-        # cancel the pre-existing timer for _wake_destinations_needing_catchup
-        # this is because we are calling it manually rather than waiting for it
-        # to be called automatically
-        assert self.federation_sender._catchup_after_startup_timer is not None
-        self.federation_sender._catchup_after_startup_timer.cancel()
-
-        self.get_success(
-            self.federation_sender._wake_destinations_needing_catchup(), by=5.0
-        )
+        # We wait quite long so that all dests can be woken up, since there is a delay
+        # between them.
+        self.pump(by=5.0)
 
         # ASSERT (_wake_destinations_needing_catchup):
         # - all remotes are woken up, save for zzzerver
         self.assertNotIn("zzzerver", woken)
-        # - all destinations are woken exactly once; they appear once in woken.
-        self.assertCountEqual(woken, server_names[:-1])
+        # - all destinations are woken, potentially more than once, since the
+        # wake up is called regularly and we don't ack in this test that a transaction
+        # has been successfully sent.
+        self.assertCountEqual(woken, set(server_names[:-1]))
 
     def test_not_latest_event(self) -> None:
         """Test that we send the latest event in the room even if its not ours."""
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 91694e4fca..a45ab83683 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -124,7 +124,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # check the right call got made to the agent
         self._mock_agent.request.assert_called_once_with(
             b"GET",
-            b"matrix://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+            b"matrix-federation://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
             headers=mock.ANY,
             bodyProducer=None,
         )
@@ -232,7 +232,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # check the right call got made to the agent
         self._mock_agent.request.assert_called_once_with(
             b"GET",
-            b"matrix://yet.another.server/_matrix/federation/v1/event/event_id",
+            b"matrix-federation://yet.another.server/_matrix/federation/v1/event/event_id",
             headers=mock.ANY,
             bodyProducer=None,
         )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..caf04b54cb 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,15 +42,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.federation_transport_client = Mock(spec=["send_transaction"])
+        self.federation_transport_client.send_transaction = AsyncMock()
         hs = self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock(  # type: ignore[method-assign]
+            return_value={"test", "host2"}
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[method-assign]
             hs.get_storage_controllers().state.get_current_hosts_in_room
         )
 
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -75,7 +75,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             thread_id=None,
             data={"ts": 1234},
         )
-        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+        self.get_success(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -104,13 +104,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts_thread(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         # Create receipts for:
         #
         # * The same room / user on multiple threads.
         # * A different user in the same room.
         sender = self.hs.get_federation_sender()
+        # Hack so that we have a txn in-flight so we batch up read receipts
+        # below
+        sender.wake_destination("host2")
         for user, thread in (
             ("alice", None),
             ("alice", "thread"),
@@ -125,9 +128,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
                 thread_id=thread,
                 data={"ts": 1234},
             )
-            self.successResultOf(
-                defer.ensureDeferred(sender.send_read_receipt(receipt))
-            )
+            defer.ensureDeferred(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -180,7 +181,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -191,7 +192,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             thread_id=None,
             data={"ts": 1234},
         )
-        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+        self.get_success(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -276,6 +277,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.federation_transport_client = Mock(
             spec=["send_transaction", "query_user_devices"]
         )
+        self.federation_transport_client.send_transaction = AsyncMock()
+        self.federation_transport_client.query_user_devices = AsyncMock()
         return self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
@@ -317,13 +320,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.record_transaction
         )
 
-    def record_transaction(
+    async def record_transaction(
         self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
-    ) -> "defer.Deferred[JsonDict]":
+    ) -> JsonDict:
         assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
-        return defer.succeed({})
+        return {}
 
     def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
@@ -340,7 +343,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.reactor.advance(1)
 
         # a second call should produce no new device EDUs
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
         self.assertEqual(self.edus, [])
 
         # a second device
@@ -354,15 +359,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
-        self.federation_transport_client.query_user_devices.return_value = (
-            make_awaitable(
-                {
-                    "stream_id": "1",
-                    "user_id": "@user2:host2",
-                    "devices": [{"device_id": "D1"}],
-                }
-            )
-        )
+        self.federation_transport_client.query_user_devices.return_value = {
+            "stream_id": "1",
+            "user_id": "@user2:host2",
+            "devices": [{"device_id": "D1"}],
+        }
 
         self.get_success(
             self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +534,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         recovery
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -552,7 +553,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -578,7 +581,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         This case tests the behaviour when the server has never been reachable.
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -603,7 +606,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -636,7 +641,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # now the server goes offline
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")
@@ -658,7 +663,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 5c850d1843..1831a5b47a 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.events import EventBase, make_event_from_dict
-from synapse.federation.federation_server import server_matches_acl_event
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
+from synapse.storage.controllers.state import server_acl_evaluator_from_event
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -67,37 +67,46 @@ class ServerACLsTestCase(unittest.TestCase):
         e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
         logging.info("ACL event: %s", e.content)
 
-        self.assertFalse(server_matches_acl_event("evil.com", e))
-        self.assertFalse(server_matches_acl_event("EVIL.COM", e))
+        server_acl_evalutor = server_acl_evaluator_from_event(e)
 
-        self.assertTrue(server_matches_acl_event("evil.com.au", e))
-        self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
+        self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com"))
+        self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM"))
+
+        self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au"))
+        self.assertTrue(
+            server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com")
+        )
 
     def test_block_ip_literals(self) -> None:
         e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
         logging.info("ACL event: %s", e.content)
 
-        self.assertFalse(server_matches_acl_event("1.2.3.4", e))
-        self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
-        self.assertFalse(server_matches_acl_event("[1:2::]", e))
-        self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+        server_acl_evalutor = server_acl_evaluator_from_event(e)
+
+        self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4"))
+        self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4"))
+        self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]"))
+        self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4"))
 
     def test_wildcard_matching(self) -> None:
         e = _create_acl_event({"allow": ["good*.com"]})
+
+        server_acl_evalutor = server_acl_evaluator_from_event(e)
+
         self.assertTrue(
-            server_matches_acl_event("good.com", e),
+            server_acl_evalutor.server_matches_acl_event("good.com"),
             "* matches 0 characters",
         )
         self.assertTrue(
-            server_matches_acl_event("GOOD.COM", e),
+            server_acl_evalutor.server_matches_acl_event("GOOD.COM"),
             "pattern is case-insensitive",
         )
         self.assertTrue(
-            server_matches_acl_event("good.aa.com", e),
+            server_acl_evalutor.server_matches_acl_event("good.aa.com"),
             "* matches several characters, including '.'",
         )
         self.assertFalse(
-            server_matches_acl_event("ishgood.com", e),
+            server_acl_evalutor.server_matches_acl_event("ishgood.com"),
             "pattern does not allow prefixes",
         )
 
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 70209ab090..b63ef3d4ed 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -218,7 +218,7 @@ class FederationKnockingTestCase(
         ) -> EventBase:
             return pdu
 
-        homeserver.get_federation_server()._check_sigs_and_hash = (  # type: ignore[assignment]
+        homeserver.get_federation_server()._check_sigs_and_hash = (  # type: ignore[method-assign]
             approve_all_signature_checking
         )
 
@@ -229,7 +229,7 @@ class FederationKnockingTestCase(
         ) -> None:
             pass
 
-        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[method-assign]
 
         return super().prepare(reactor, clock, homeserver)
 
@@ -308,7 +308,7 @@ class FederationKnockingTestCase(
         self.assertEqual(200, channel.code, channel.result)
 
         # Check that we got the stripped room state in return
-        room_state_events = channel.json_body["knock_state_events"]
+        room_state_events = channel.json_body["knock_room_state"]
 
         # Validate the stripped room state events
         self.check_knock_room_state_against_room_state(
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..867dbd6001 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
@@ -31,12 +31,12 @@ from synapse.appservice import (
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.rest.client import login, receipts, register, room, sendtodevice
 from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection
 from tests.unittest import override_config
 from tests.utils import MockClock
 
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def setUp(self) -> None:
         self.mock_store = Mock()
-        self.mock_as_api = Mock()
+        self.mock_as_api = AsyncMock()
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastores.return_value = Mock(main=self.mock_store)
-        self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
-            None
-        )
+        self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
         hs.get_application_service_api.return_value = self.mock_as_api
         hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
@@ -69,22 +67,26 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested_in_event=False),
         ]
 
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
+        self.mock_as_api.query_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable([])
+        self.mock_store.get_user_by_id = AsyncMock(return_value=[])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {})),
-            make_awaitable((1, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [
-            make_awaitable([]),
-            make_awaitable([event]),
-        ]
-        self.handler.notify_interested_services(RoomStreamToken(None, 1))
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {}),
+                (1, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(
+            side_effect=[
+                [],
+                [event],
+            ]
+        )
+        self.handler.notify_interested_services(RoomStreamToken(stream=1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, events=[event]
@@ -95,15 +97,17 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+        self.mock_store.get_user_by_id = AsyncMock(return_value=None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
 
@@ -112,15 +116,17 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+        self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, [event], {event.event_id: 0})),
-        ]
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, [event], {event.event_id: 0}),
+            ]
+        )
 
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.assertFalse(
             self.mock_as_api.query_user.called,
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_room_alias_in_namespace=False),
         ]
 
-        self.mock_as_api.query_alias.return_value = make_awaitable(True)
+        self.mock_as_api.query_alias = AsyncMock(return_value=True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
-            Mock(room_id=room_id, servers=servers)
+        self.mock_store.get_association_from_room_alias = AsyncMock(
+            return_value=Mock(room_id=room_id, servers=servers)
         )
 
         result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_protocol_no_response(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+        self.mock_as_api.get_3pe_protocol.return_value = None
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_select_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
         )
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         service_one = self._mkservice(False, ["my-protocol"])
         service_two = self._mkservice(False, ["other-protocol"])
         self.mock_store.get_app_services.return_value = [service_one, service_two]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -287,17 +296,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            579
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
-            "receipt_key", 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
         )
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, ephemeral=[event]
@@ -317,17 +324,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [interested_service]
 
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            580
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
-            "receipt_key", 580, ["@fakerecipient:example.com"]
+            StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
         )
         # This method will be called, but with an empty list of events
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             A mock representing the ApplicationService.
         """
         service = Mock()
-        service.is_interested_in_event.return_value = make_awaitable(
-            is_interested_in_event
-        )
+        service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         service.protocols = protocols
@@ -396,12 +399,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.hs = hs
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track any outgoing ephemeral events
-        self.send_mock = simple_async_mock()
-        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]
+        self.send_mock = AsyncMock()
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[method-assign]
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[method-assign]
             return_value=self._services
         )
 
@@ -419,6 +422,18 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
             "exclusive_as_user", "password", self.exclusive_as_user_device_id
         )
 
+        self.exclusive_as_user_2_device_id = "exclusive_as_device_2"
+        self.exclusive_as_user_2 = self.register_user("exclusive_as_user_2", "password")
+        self.exclusive_as_user_2_token = self.login(
+            "exclusive_as_user_2", "password", self.exclusive_as_user_2_device_id
+        )
+
+        self.exclusive_as_user_3_device_id = "exclusive_as_device_3"
+        self.exclusive_as_user_3 = self.register_user("exclusive_as_user_3", "password")
+        self.exclusive_as_user_3_token = self.login(
+            "exclusive_as_user_3", "password", self.exclusive_as_user_3_device_id
+        )
+
     def _notify_interested_services(self) -> None:
         # This is normally set in `notify_interested_services` but we need to call the
         # internal async version so the reactor gets pushed to completion.
@@ -426,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.hs.get_application_service_handler()._notify_interested_services(
                 RoomStreamToken(
-                    None, self.hs.get_application_service_handler().current_max
+                    stream=self.hs.get_application_service_handler().current_max
                 )
             )
         )
@@ -619,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
             self.get_success(
                 self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
                     services=[interested_appservice],
-                    stream_key="receipt_key",
+                    stream_key=StreamKeyType.RECEIPT,
                     new_token=stream_token,
                     users=[self.exclusive_as_user],
                 )
@@ -846,6 +861,119 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         for count in service_id_to_message_count.values():
             self.assertEqual(count, number_of_messages)
 
+    @unittest.override_config(
+        {"experimental_features": {"msc2409_to_device_messages_enabled": True}}
+    )
+    def test_application_services_receive_local_to_device_for_many_users(self) -> None:
+        """
+        Test that when a user sends a to-device message to many users
+        in an application service's user namespace, the
+        application service will receive all of them.
+        """
+        interested_appservice = self._register_application_service(
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {
+                        "regex": "@exclusive_as_user:.+",
+                        "exclusive": True,
+                    },
+                    {
+                        "regex": "@exclusive_as_user_2:.+",
+                        "exclusive": True,
+                    },
+                    {
+                        "regex": "@exclusive_as_user_3:.+",
+                        "exclusive": True,
+                    },
+                ],
+            },
+        )
+
+        # Have local_user send a to-device message to exclusive_as_users
+        message_content = {"some_key": "some really interesting value"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+            content={
+                "messages": {
+                    self.exclusive_as_user: {
+                        self.exclusive_as_user_device_id: message_content
+                    },
+                    self.exclusive_as_user_2: {
+                        self.exclusive_as_user_2_device_id: message_content
+                    },
+                    self.exclusive_as_user_3: {
+                        self.exclusive_as_user_3_device_id: message_content
+                    },
+                }
+            },
+            access_token=self.local_user_token,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        # Have exclusive_as_user send a to-device message to local_user
+        for user_token in [
+            self.exclusive_as_user_token,
+            self.exclusive_as_user_2_token,
+            self.exclusive_as_user_3_token,
+        ]:
+            chan = self.make_request(
+                "PUT",
+                "/_matrix/client/r0/sendToDevice/m.room_key_request/4",
+                content={
+                    "messages": {
+                        self.local_user: {self.local_user_device_id: message_content}
+                    }
+                },
+                access_token=user_token,
+            )
+            self.assertEqual(chan.code, 200, chan.result)
+
+        # Check if our application service - that is interested in exclusive_as_user - received
+        # the to-device message as part of an AS transaction.
+        # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
+        #
+        # The uninterested application service should not have been notified at all.
+        self.send_mock.assert_called_once()
+        (
+            service,
+            _events,
+            _ephemeral,
+            to_device_messages,
+            _otks,
+            _fbks,
+            _device_list_summary,
+        ) = self.send_mock.call_args[0]
+
+        # Assert that this was the same to-device message that local_user sent
+        self.assertEqual(service, interested_appservice)
+
+        # Assert expected number of messages
+        self.assertEqual(len(to_device_messages), 3)
+
+        for device_msg in to_device_messages:
+            self.assertEqual(device_msg["type"], "m.room_key_request")
+            self.assertEqual(device_msg["sender"], self.local_user)
+            self.assertEqual(device_msg["content"], message_content)
+
+        self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
+        self.assertEqual(
+            to_device_messages[0]["to_device_id"],
+            self.exclusive_as_user_device_id,
+        )
+
+        self.assertEqual(to_device_messages[1]["to_user_id"], self.exclusive_as_user_2)
+        self.assertEqual(
+            to_device_messages[1]["to_device_id"],
+            self.exclusive_as_user_2_device_id,
+        )
+
+        self.assertEqual(to_device_messages[2]["to_user_id"], self.exclusive_as_user_3)
+        self.assertEqual(
+            to_device_messages[2]["to_device_id"],
+            self.exclusive_as_user_3_device_id,
+        )
+
     def _register_application_service(
         self,
         namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
@@ -894,12 +1022,12 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
 
         # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
         # will be sent over the wire
-        self.put_json = simple_async_mock()
-        hs.get_application_service_api().put_json = self.put_json  # type: ignore[assignment]
+        self.put_json = AsyncMock()
+        hs.get_application_service_api().put_json = self.put_json  # type: ignore[method-assign]
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[method-assign]
             return_value=self._services
         )
 
@@ -1000,8 +1128,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track what's going out
-        self.send_mock = simple_async_mock()
-        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]  # We assign to a method.
+        self.send_mock = AsyncMock()
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[method-assign]  # We assign to a method.
 
         # Define an application service for the tests
         self._service_token = "VERYSECRET"
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 036dbbc45b..413ff8795b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 import pymacaroons
 
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class AuthTestCase(unittest.HomeserverTestCase):
@@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_mau_limits_exceeded_large(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.large_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.large_number_of_users
         )
 
         self.get_failure(
@@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.large_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.large_number_of_users
         )
         token = self.get_success(
             self.auth_handler.create_login_token_for_user_id(self.user1)
@@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
 
         # Set the server to be at the edge of too many users.
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.auth_blocking._max_mau_value
         )
 
         # If not in monthly active cohort
@@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertIsNone(self.token_login(token))
 
         # If in monthly active cohort
-        self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.clock.time_msec())
+        self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock(
+            return_value=self.clock.time_msec()
         )
         self.get_success(
             self.auth_handler.create_access_token_for_user_id(
@@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.small_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.small_number_of_users
         )
         # Ensure does not raise exception
         self.get_success(
@@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.small_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.small_number_of_users
         )
         token = self.get_success(
             self.auth_handler.create_login_token_for_user_id(self.user1)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 63aad0d10c..13e2cd153a 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -12,7 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
@@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         cas_response = CasResponse("test_user", {})
         request = _mock_request()
@@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # Map a user via SSO.
         cas_response = CasResponse("test_user", {})
@@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         cas_response = CasResponse("föö", {})
         request = _mock_request()
@@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # The response doesn't have the proper userGroup or department.
         cas_response = CasResponse("test_user", {})
@@ -198,6 +197,23 @@ class CasHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
+    @override_config({"cas_config": {"enable_registration": False}})
+    def test_map_cas_user_does_not_register_new_user(self) -> None:
+        """Ensures new users are not registered if the enabled registration flag is disabled."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
+
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler was not called as expected
+        auth_handler.complete_sso_login.assert_not_called()
+
 
 def _mock_request() -> Mock:
     """Returns a mock which will stand in as a SynapseRequest"""
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..d4ed068357 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,19 +17,22 @@
 from typing import Optional
 from unittest import mock
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 user1 = "@boris:aaa"
@@ -38,16 +41,16 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.appservice_api = mock.Mock()
+        self.appservice_api = mock.AsyncMock()
         hs = self.setup_test_homeserver(
             "server",
-            federation_http_client=None,
             application_service_api=self.appservice_api,
         )
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
         self.store = hs.get_datastores().main
+        self.device_message_handler = hs.get_device_message_handler()
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -121,50 +124,50 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(3, len(res))
         device_map = {d["device_id"]: d for d in res}
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "xyz",
                 "display_name": "display 0",
                 "last_seen_ip": None,
                 "last_seen_ts": None,
-            },
-            device_map["xyz"],
+            }.items(),
+            device_map["xyz"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "fco",
                 "display_name": "display 1",
                 "last_seen_ip": "ip1",
                 "last_seen_ts": 1000000,
-            },
-            device_map["fco"],
+            }.items(),
+            device_map["fco"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "abc",
                 "display_name": "display 2",
                 "last_seen_ip": "ip3",
                 "last_seen_ts": 3000000,
-            },
-            device_map["abc"],
+            }.items(),
+            device_map["abc"].items(),
         )
 
     def test_get_device(self) -> None:
         self._record_users()
 
         res = self.get_success(self.handler.get_device(user1, "abc"))
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "abc",
                 "display_name": "display 2",
                 "last_seen_ip": "ip3",
                 "last_seen_ts": 3000000,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_delete_device(self) -> None:
@@ -210,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
         self.assertIsNone(res)
 
+    def test_delete_device_and_big_device_inbox(self) -> None:
+        """Check that deleting a big device inbox is staged and batched asynchronously."""
+        DEVICE_ID = "abc"
+        sender = "@sender:" + self.hs.hostname
+        receiver = "@receiver:" + self.hs.hostname
+        self._record_user(sender, DEVICE_ID, DEVICE_ID)
+        self._record_user(receiver, DEVICE_ID, DEVICE_ID)
+
+        # queue a bunch of messages in the inbox
+        requester = create_requester(sender, device_id=DEVICE_ID)
+        for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
+            self.get_success(
+                self.device_message_handler.send_device_message(
+                    requester, "message_type", {receiver: {"*": {"val": i}}}
+                )
+            )
+
+        # delete the device
+        self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
+
+        # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="device_inbox",
+                keyvalues={"user_id": receiver},
+                retcols=("user_id", "device_id", "stream_id"),
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(10, len(res))
+
+        # wait for the task scheduler to do a second delete pass
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+        # remaining messages should now be deleted
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="device_inbox",
+                keyvalues={"user_id": receiver},
+                retcols=("user_id", "device_id", "stream_id"),
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(0, len(res))
+
     def test_update_device(self) -> None:
         self._record_users()
 
@@ -373,13 +421,11 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.query_keys.return_value = make_awaitable(
-            {
-                "device_keys": {
-                    local_user: {device_2: device_key_2b, device_3: device_key_3}
-                }
+        self.appservice_api.query_keys.return_value = {
+            "device_keys": {
+                local_user: {device_2: device_key_2b, device_3: device_key_3}
             }
-        )
+        }
 
         # Request all devices.
         res = self.get_success(
@@ -400,13 +446,22 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        devices.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver("server", federation_http_client=None)
+        hs = self.setup_test_homeserver("server")
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
+        self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
         return hs
 
@@ -419,6 +474,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         stored_dehydrated_device_id = self.get_success(
             self.handler.store_dehydrated_device(
                 user_id=user_id,
+                device_id=None,
                 device_data={"device_data": {"foo": "bar"}},
                 initial_device_display_name="dehydrated device",
             )
@@ -432,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
 
         # Create a new login for the user and dehydrated the device
-        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+        device_id, access_token, _expiration_time, refresh_token = self.get_success(
             self.registration.register_device(
                 user_id=user_id,
                 device_id=None,
                 initial_display_name="new device",
+                should_issue_refresh_token=True,
             )
         )
 
@@ -467,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(user_info.device_id, retrieved_device_id)
 
+        # make sure the user device has the refresh token
+        assert refresh_token is not None
+        self.get_success(
+            self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
+        )
+
         # make sure the device has the display name that was set from the login
         res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
 
@@ -482,3 +545,89 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
 
         self.assertIsNone(ret)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_v2_and_fetch_events(self) -> None:
+        user_id = "@boris:server"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_id=None,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        device_info = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+        assert device_info is not None
+        retrieved_device_id, device_data = device_info
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user
+        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+
+        requester = create_requester(user_id, device_id=device_id)
+
+        # Fetching messages for a non-existing device should return an error
+        self.get_failure(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id="not the right device ID",
+                since_token=None,
+                limit=10,
+            ),
+            SynapseError,
+        )
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+            )
+        )
+        self.pump()
+
+        # Fetch the message of the dehydrated device
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=None,
+                limit=10,
+            )
+        )
+
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+        # Fetch the message of the dehydrated device again, which should return
+        # the same message as it has not been deleted
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=None,
+                limit=10,
+            )
+        )
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 90aec484c4..367d94eca3 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.mock_federation = Mock()
+        self.mock_federation = AsyncMock()
         self.mock_registry = Mock()
 
         self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self) -> None:
-        self.mock_federation.make_query.return_value = make_awaitable(
-            {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
-        )
+        self.mock_federation.make_query.return_value = {
+            "room_id": "!8765qwer:test",
+            "servers": ["test", "remote"],
+        }
 
         result = self.get_success(self.handler.get_association(self.remote_room))
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 72d0584061..c5556f2844 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable
+from typing import Dict, Iterable
 from unittest import mock
 
 from parameterized import parameterized
@@ -27,17 +27,16 @@ from synapse.appservice import ApplicationService
 from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.appservice_api = mock.Mock()
+        self.appservice_api = mock.AsyncMock()
         return self.setup_test_homeserver(
             federation_client=mock.Mock(), application_service_api=self.appservice_api
         )
@@ -45,6 +44,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastores().main
+        self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
 
     def test_query_local_devices_no_devices(self) -> None:
         """If the user has no devices, we expect an empty list."""
@@ -161,6 +161,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -206,6 +207,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -225,6 +227,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -274,6 +277,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -286,6 +290,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -307,6 +312,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -348,6 +354,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -370,6 +377,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -792,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_client_keys = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
-                    "device_keys": {remote_user_id: {}},
-                    "master_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["master"],
-                            "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                        },
-                    },
-                    "self_signing_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["self_signing"],
-                            "keys": {
-                                "ed25519:"
-                                + remote_self_signing_key: remote_self_signing_key
-                            },
-                        }
+        self.hs.get_federation_client().query_client_keys = mock.AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "device_keys": {remote_user_id: {}},
+                "master_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                }
-            )
+                },
+                "self_signing_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    }
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -865,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         # Pretend we're sharing a room with the user we're querying. If not,
         # `_query_devices_for_destination` will return early.
-        self.store.get_rooms_for_user = mock.Mock(
-            return_value=make_awaitable({"some_room_id"})
-        )
+        self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
 
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_user_devices = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
+        self.hs.get_federation_client().query_user_devices = mock.AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "user_id": remote_user_id,
+                "stream_id": 1,
+                "devices": [],
+                "master_key": {
                     "user_id": remote_user_id,
-                    "stream_id": 1,
-                    "devices": [],
-                    "master_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                    },
-                    "self_signing_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["self_signing"],
-                        "keys": {
-                            "ed25519:"
-                            + remote_self_signing_key: remote_self_signing_key
-                        },
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+                "self_signing_key": {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
                     },
-                }
-            )
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -978,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         mock_get_rooms = mock.patch.object(
             self.store,
             "get_rooms_for_user",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(["some_room_id"]),
+            new_callable=mock.AsyncMock,
+            return_value=["some_room_id"],
         )
         mock_get_users = mock.patch.object(
             self.store,
             "get_users_server_still_shares_room_with",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable({remote_user_id}),
+            new_callable=mock.AsyncMock,
+            return_value={remote_user_id},
         )
         mock_request = mock.patch.object(
             self.hs.get_federation_client(),
             "query_user_devices",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(response_body),
+            new_callable=mock.AsyncMock,
+            return_value=response_body,
         )
 
         with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@@ -1051,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response, but only for device 2.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_2: otk}},
+            [(local_user, device_id_1, "alg1", 1)],
         )
 
         # we shouldn't have any unused fallback keys yet
@@ -1080,6 +1082,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -1117,14 +1120,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
-        )
+        response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
+            local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+        }
+        self.appservice_api.claim_client_keys.return_value = (response, [])
 
         # Claim OTKs, which will ask the appservice and do nothing else.
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1160,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # The appservice will return only the OTK.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_otk}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_otk}},
+            [],
         )
 
         # Claim OTKs, which should return the OTK from the appservice and the
@@ -1169,6 +1175,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1202,6 +1209,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1221,14 +1229,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # Finally, return only the fallback key from the appservice.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_fallback_key}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_fallback_key}},
+            [],
         )
 
         # Claim OTKs, which will return only the fallback key from the database.
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1336,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.query_keys.return_value = make_awaitable(
-            {
-                "device_keys": {
-                    local_user: {device_2: device_key_2b, device_3: device_key_3}
-                }
+        self.appservice_api.query_keys.return_value = {
+            "device_keys": {
+                local_user: {device_2: device_key_2b, device_3: device_key_3}
             }
-        )
+        }
 
         # Request all devices.
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf0862ed54..4fc0742413 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,7 +14,7 @@
 import logging
 from typing import Collection, Optional, cast
 from unittest import TestCase
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
@@ -40,7 +40,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 
 logger = logging.getLogger(__name__)
 
@@ -57,7 +57,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver(federation_http_client=None)
+        hs = self.setup_test_homeserver()
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
         return hs
@@ -262,7 +262,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
             if (ev.type, ev.state_key)
             in {("m.room.create", ""), ("m.room.member", remote_server_user_id)}
         ]
-        for _ in range(0, 8):
+        for _ in range(8):
             event = make_event_from_dict(
                 self.add_hashes_and_signatures_from_other_server(
                     {
@@ -370,15 +370,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
 
         # We mock out the FederationClient.backfill method, to pretend that a remote
         # server has returned our fake event.
-        federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
-        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[assignment]
+        federation_client_backfill_mock = AsyncMock(return_value=[event])
+        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[method-assign]
 
         # We also mock the persist method with a side effect of itself. This allows us
         # to track when it has been called while preserving its function.
         persist_events_and_notify_mock = Mock(
             side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
         )
-        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[assignment]
+        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[method-assign]
             persist_events_and_notify_mock
         )
 
@@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             },
             RoomVersions.V10,
         )
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    "example.com",
-                    membership_event,
-                    RoomVersions.V10,
-                )
+        mock_make_membership_event = AsyncMock(
+            return_value=(
+                "example.com",
+                membership_event,
+                RoomVersions.V10,
             )
         )
-        mock_send_join = Mock(
-            return_value=make_awaitable(
-                SendJoinResult(
-                    membership_event,
-                    "example.com",
-                    state=[
-                        EVENT_CREATE,
-                        EVENT_CREATOR_MEMBERSHIP,
-                        EVENT_INVITATION_MEMBERSHIP,
-                    ],
-                    auth_chain=[
-                        EVENT_CREATE,
-                        EVENT_CREATOR_MEMBERSHIP,
-                        EVENT_INVITATION_MEMBERSHIP,
-                    ],
-                    partial_state=True,
-                    servers_in_room={"example.com"},
-                )
+        mock_send_join = AsyncMock(
+            return_value=SendJoinResult(
+                membership_event,
+                "example.com",
+                state=[
+                    EVENT_CREATE,
+                    EVENT_CREATOR_MEMBERSHIP,
+                    EVENT_INVITATION_MEMBERSHIP,
+                ],
+                auth_chain=[
+                    EVENT_CREATE,
+                    EVENT_CREATOR_MEMBERSHIP,
+                    EVENT_INVITATION_MEMBERSHIP,
+                ],
+                partial_state=True,
+                servers_in_room={"example.com"},
             )
         )
 
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index c067e5bfe3..70e6a7e142 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -35,7 +35,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 
 
 class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         self.mock_federation_transport_client = mock.Mock(
             spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
         )
+        self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
+        self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
+        self.mock_federation_transport_client.get_event = mock.AsyncMock()
+        self.mock_federation_transport_client.backfill = mock.AsyncMock()
         return super().setup_test_homeserver(
             federation_transport_client=self.mock_federation_transport_client
         )
@@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
 
         # we expect an outbound request to /state_ids, so stub that out
-        self.mock_federation_transport_client.get_room_state_ids.return_value = (
-            make_awaitable(
-                {
-                    "pdu_ids": [e.event_id for e in state_at_prev_event],
-                    "auth_chain_ids": [],
-                }
-            )
-        )
+        self.mock_federation_transport_client.get_room_state_ids.return_value = {
+            "pdu_ids": [e.event_id for e in state_at_prev_event],
+            "auth_chain_ids": [],
+        }
 
         # we also expect an outbound request to /state
         self.mock_federation_transport_client.get_room_state.return_value = (
-            make_awaitable(
-                StateRequestResponse(auth_events=[], state=state_at_prev_event)
-            )
+            StateRequestResponse(auth_events=[], state=state_at_prev_event)
         )
 
         # we have to bump the clock a bit, to keep the retry logic in
@@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         room_version = self.get_success(main_store.get_room_version(room_id))
 
         # We expect an outbound request to /state_ids, so stub that out
-        self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
-            {
-                # Mimic the other server not knowing about the state at all.
-                # We want to cause Synapse to throw an error (`Unable to get
-                # missing prev_event $fake_prev_event`) and fail to backfill
-                # the pulled event.
-                "pdu_ids": [],
-                "auth_chain_ids": [],
-            }
-        )
+        self.mock_federation_transport_client.get_room_state_ids.return_value = {
+            # Mimic the other server not knowing about the state at all.
+            # We want to cause Synapse to throw an error (`Unable to get
+            # missing prev_event $fake_prev_event`) and fail to backfill
+            # the pulled event.
+            "pdu_ids": [],
+            "auth_chain_ids": [],
+        }
+
         # We also expect an outbound request to /state
-        self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
-            StateRequestResponse(
-                # Mimic the other server not knowing about the state at all.
-                # We want to cause Synapse to throw an error (`Unable to get
-                # missing prev_event $fake_prev_event`) and fail to backfill
-                # the pulled event.
-                auth_events=[],
-                state=[],
-            )
+        self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
+            # Mimic the other server not knowing about the state at all.
+            # We want to cause Synapse to throw an error (`Unable to get
+            # missing prev_event $fake_prev_event`) and fail to backfill
+            # the pulled event.
+            auth_events=[],
+            state=[],
         )
 
         pulled_event = make_event_from_dict(
@@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
 
         # We expect an outbound request to /backfill, so stub that out
-        self.mock_federation_transport_client.backfill.return_value = make_awaitable(
-            {
-                "origin": self.OTHER_SERVER_NAME,
-                "origin_server_ts": 123,
-                "pdus": [
-                    # This is one of the important aspects of this test: we include
-                    # `pulled_event_without_signatures` so it fails the signature check
-                    # when we filter down the backfill response down to events which
-                    # have valid signatures in
-                    # `_check_sigs_and_hash_for_pulled_events_and_fetch`
-                    pulled_event_without_signatures.get_pdu_json(),
-                    # Then later when we process this valid signature event, when we
-                    # fetch the missing `prev_event`s, we want to make sure that we
-                    # backoff and don't try and fetch `pulled_event_without_signatures`
-                    # again since we know it just had an invalid signature.
-                    pulled_event.get_pdu_json(),
-                ],
-            }
-        )
+        self.mock_federation_transport_client.backfill.return_value = {
+            "origin": self.OTHER_SERVER_NAME,
+            "origin_server_ts": 123,
+            "pdus": [
+                # This is one of the important aspects of this test: we include
+                # `pulled_event_without_signatures` so it fails the signature check
+                # when we filter down the backfill response down to events which
+                # have valid signatures in
+                # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+                pulled_event_without_signatures.get_pdu_json(),
+                # Then later when we process this valid signature event, when we
+                # fetch the missing `prev_event`s, we want to make sure that we
+                # backoff and don't try and fetch `pulled_event_without_signatures`
+                # again since we know it just had an invalid signature.
+                pulled_event.get_pdu_json(),
+            ],
+        }
 
         # Keep track of the count and make sure we don't make any of these requests
         event_endpoint_requested_count = 0
@@ -664,6 +657,99 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             StoreError,
         )
 
+    def test_backfill_process_previously_failed_pull_attempt_event_in_the_background(
+        self,
+    ) -> None:
+        """
+        Sanity check that events are still processed even if it is in the background
+        for events that already have failed pull attempts.
+        """
+        OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+        main_store = self.hs.get_datastores().main
+
+        # Create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+        room_version = self.get_success(main_store.get_room_version(room_id))
+
+        # Allow the remote user to send state events
+        self.helper.send_state(
+            room_id,
+            "m.room.power_levels",
+            {"events_default": 0, "state_default": 0},
+            tok=tok,
+        )
+
+        # Add the remote user to the room
+        member_event = self.get_success(
+            event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+        )
+
+        initial_state_map = self.get_success(
+            main_store.get_partial_current_state_ids(room_id)
+        )
+
+        auth_event_ids = [
+            initial_state_map[("m.room.create", "")],
+            initial_state_map[("m.room.power_levels", "")],
+            member_event.event_id,
+        ]
+
+        # Create a regular event that should process
+        pulled_event = make_event_from_dict(
+            self.add_hashes_and_signatures_from_other_server(
+                {
+                    "type": "test_regular_type",
+                    "room_id": room_id,
+                    "sender": OTHER_USER,
+                    "prev_events": [
+                        member_event.event_id,
+                    ],
+                    "auth_events": auth_event_ids,
+                    "origin_server_ts": 1,
+                    "depth": 12,
+                    "content": {"body": "pulled_event"},
+                }
+            ),
+            room_version,
+        )
+
+        # Record a failed pull attempt for this event which will cause us to backfill it
+        # in the background from here on out.
+        self.get_success(
+            main_store.record_event_failed_pull_attempt(
+                room_id, pulled_event.event_id, "fake cause"
+            )
+        )
+
+        # We expect an outbound request to /backfill, so stub that out
+        self.mock_federation_transport_client.backfill.return_value = {
+            "origin": self.OTHER_SERVER_NAME,
+            "origin_server_ts": 123,
+            "pdus": [
+                pulled_event.get_pdu_json(),
+            ],
+        }
+
+        # The function under test: try to backfill and process the pulled event
+        with LoggingContext("test"):
+            self.get_success(
+                self.hs.get_federation_event_handler().backfill(
+                    self.OTHER_SERVER_NAME,
+                    room_id,
+                    limit=1,
+                    extremities=["$some_extremity"],
+                )
+            )
+
+        # Ensure `run_as_background_process(...)` has a chance to run (essentially
+        # `wait_for_background_processes()`)
+        self.reactor.pump((0.1,))
+
+        # Make sure we processed and persisted the pulled event
+        self.get_success(main_store.get_event(pulled_event.event_id, allow_none=False))
+
     def test_process_pulled_event_with_rejected_missing_state(self) -> None:
         """Ensure that we correctly handle pulled events with missing state containing a
         rejected state event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9691d66b48..1c5897c84e 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,18 +46,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self._persist_event_storage_controller = persistence
 
         self.user_id = self.register_user("tester", "foobar")
-        self.access_token = self.login("tester", "foobar")
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
-
-        info = self.get_success(
-            self.hs.get_datastores().main.get_user_by_access_token(
-                self.access_token,
-            )
-        )
-        assert info is not None
-        self.token_id = info.token_id
+        device_id = "dev-1"
+        access_token = self.login("tester", "foobar", device_id=device_id)
+        self.room_id = self.helper.create_room_as(self.user_id, tok=access_token)
 
-        self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+        self.requester = create_requester(self.user_id, device_id=device_id)
 
     def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
         # Create a member event we can use as an auth_event
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
new file mode 100644
index 0000000000..a72ecfdc97
--- /dev/null
+++ b/tests/handlers/test_oauth_delegation.py
@@ -0,0 +1,687 @@
+# Copyright 2022 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Dict, Union
+from unittest.mock import ANY, AsyncMock, Mock
+from urllib.parse import parse_qs
+
+from signedjson.key import (
+    encode_verify_key_base64,
+    generate_signing_key,
+    get_verify_key,
+)
+from signedjson.sign import sign_json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    InvalidClientTokenError,
+    OAuthInsufficientScopeError,
+    SynapseError,
+)
+from synapse.rest import admin
+from synapse.rest.client import account, devices, keys, login, logout, register
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.test_utils import FakeResponse, get_awaitable_result
+from tests.unittest import HomeserverTestCase, skip_unless
+from tests.utils import mock_getRawHeaders
+
+try:
+    import authlib  # noqa: F401
+
+    HAS_AUTHLIB = True
+except ImportError:
+    HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+INTROSPECTION_ENDPOINT = ISSUER + "introspect"
+
+SYNAPSE_ADMIN_SCOPE = "urn:synapse:admin:*"
+MATRIX_USER_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:*"
+MATRIX_GUEST_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:guest"
+MATRIX_DEVICE_SCOPE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:"
+DEVICE = "AABBCCDD"
+MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE
+SUBJECT = "abc-def-ghi"
+USERNAME = "test-user"
+USER_ID = "@" + USERNAME + ":" + SERVER_NAME
+
+
+async def get_json(url: str) -> JsonDict:
+    # Mock get_json calls to handle jwks & oidc discovery endpoints
+    if url == WELL_KNOWN:
+        # Minimal discovery document, as defined in OpenID.Discovery
+        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+        return {
+            "issuer": ISSUER,
+            "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+            "token_endpoint": TOKEN_ENDPOINT,
+            "jwks_uri": JWKS_URI,
+            "userinfo_endpoint": USERINFO_ENDPOINT,
+            "introspection_endpoint": INTROSPECTION_ENDPOINT,
+            "response_types_supported": ["code"],
+            "subject_types_supported": ["public"],
+            "id_token_signing_alg_values_supported": ["RS256"],
+        }
+    elif url == JWKS_URI:
+        return {"keys": []}
+
+    return {}
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        devices.register_servlets,
+        keys.register_servlets,
+        register.register_servlets,
+        login.register_servlets,
+        logout.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        config["disable_registration"] = True
+        config["experimental_features"] = {
+            "msc3861": {
+                "enabled": True,
+                "issuer": ISSUER,
+                "client_id": CLIENT_ID,
+                "client_auth_method": "client_secret_post",
+                "client_secret": CLIENT_SECRET,
+                "admin_token": "admin_token_value",
+            }
+        }
+        return config
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.http_client = Mock(spec=["get_json"])
+        self.http_client.get_json.side_effect = get_json
+        self.http_client.user_agent = b"Synapse Test"
+
+        hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+        self.auth = hs.get_auth()
+
+        return hs
+
+    def _assertParams(self) -> None:
+        """Assert that the request parameters are correct."""
+        params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
+        self.assertEqual(params["token"], ["mockAccessToken"])
+        self.assertEqual(params["client_id"], [CLIENT_ID])
+        self.assertEqual(params["client_secret"], [CLIENT_SECRET])
+
+    def test_inactive_token(self) -> None:
+        """The handler should return a 403 where the token is inactive."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={"active": False},
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+
+    def test_active_no_scope(self) -> None:
+        """The handler should return a 403 where no scope is given."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={"active": True},
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+
+    def test_active_user_no_subject(self) -> None:
+        """The handler should return a 500 when no subject is present."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+
+    def test_active_no_user_scope(self) -> None:
+        """The handler should return a 500 when no subject is present."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_DEVICE_SCOPE]),
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+
+    def test_active_admin_not_user(self) -> None:
+        """The handler should raise when the scope has admin right but not user."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+
+    def test_active_admin(self) -> None:
+        """The handler should return a requester with admin rights."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+        self.assertEqual(requester.is_guest, False)
+        self.assertEqual(requester.device_id, None)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), True
+        )
+
+    def test_active_admin_highest_privilege(self) -> None:
+        """The handler should resolve to the most permissive scope."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join(
+                        [SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
+                    ),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+        self.assertEqual(requester.is_guest, False)
+        self.assertEqual(requester.device_id, None)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), True
+        )
+
+    def test_active_user(self) -> None:
+        """The handler should return a requester with normal user rights."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_USER_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+        self.assertEqual(requester.is_guest, False)
+        self.assertEqual(requester.device_id, None)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), False
+        )
+
+    def test_active_user_with_device(self) -> None:
+        """The handler should return a requester with normal user rights and a device ID."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+        self.assertEqual(requester.is_guest, False)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), False
+        )
+        self.assertEqual(requester.device_id, DEVICE)
+
+    def test_multiple_devices(self) -> None:
+        """The handler should raise an error if multiple devices are found in the scope."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join(
+                        [
+                            MATRIX_USER_SCOPE,
+                            f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
+                            f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
+                        ]
+                    ),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        self.get_failure(self.auth.get_user_by_req(request), AuthError)
+
+    def test_active_guest_not_allowed(self) -> None:
+        """The handler should return an insufficient scope error."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        error = self.get_failure(
+            self.auth.get_user_by_req(request), OAuthInsufficientScopeError
+        )
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(
+            getattr(error.value, "headers", {})["WWW-Authenticate"],
+            'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"',
+        )
+
+    def test_active_guest_allowed(self) -> None:
+        """The handler should return a requester with guest user rights and a device ID."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(
+            self.auth.get_user_by_req(request, allow_guest=True)
+        )
+        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.http_client.request.assert_called_once_with(
+            method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+        )
+        self._assertParams()
+        self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+        self.assertEqual(requester.is_guest, True)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), False
+        )
+        self.assertEqual(requester.device_id, DEVICE)
+
+    def test_unavailable_introspection_endpoint(self) -> None:
+        """The handler should return an internal server error."""
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        # The introspection endpoint is returning an error.
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse(code=500, body=b"Internal Server Error")
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint request fails.
+        self.http_client.request = AsyncMock(side_effect=Exception())
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint does not return a JSON object.
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200, payload=["this is an array", "not an object"]
+            )
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint does not return valid JSON.
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse(code=200, body=b"this is not valid JSON")
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+    def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+        # We only generate a master key to simplify the test.
+        master_signing_key = generate_signing_key(device_id)
+        master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+        return {
+            "master_key": sign_json(
+                {
+                    "user_id": user_id,
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + master_verify_key: master_verify_key},
+                },
+                user_id,
+                master_signing_key,
+            ),
+        }
+
+    def test_cross_signing(self) -> None:
+        """Try uploading device keys with OAuth delegation enabled."""
+
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+        keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys_upload_body,
+            access_token="mockAccessToken",
+        )
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys_upload_body,
+            access_token="mockAccessToken",
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
+
+    def expect_unauthorized(
+        self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+    ) -> None:
+        channel = self.make_request(method, path, content, shorthand=False)
+
+        self.assertEqual(channel.code, 401, channel.json_body)
+
+    def expect_unrecognized(
+        self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+    ) -> None:
+        channel = self.make_request(method, path, content)
+
+        self.assertEqual(channel.code, 404, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
+        )
+
+    def test_uia_endpoints(self) -> None:
+        """Test that endpoints that were removed in MSC2964 are no longer available."""
+
+        # This is just an endpoint that should remain visible (but requires auth):
+        self.expect_unauthorized("GET", "/_matrix/client/v3/devices")
+
+        # This remains usable, but will require a uia scope:
+        self.expect_unauthorized(
+            "POST", "/_matrix/client/v3/keys/device_signing/upload"
+        )
+
+    def test_3pid_endpoints(self) -> None:
+        """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
+
+        # Remains and requires auth:
+        self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
+        self.expect_unauthorized(
+            "POST",
+            "/_matrix/client/v3/account/3pid/bind",
+            {
+                "client_secret": "foo",
+                "id_access_token": "bar",
+                "id_server": "foo",
+                "sid": "bar",
+            },
+        )
+        self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
+
+        # These are gone:
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/account/3pid"
+        )  # deprecated
+        self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/account/3pid/email/requestToken"
+        )
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
+        )
+
+    def test_account_management_endpoints_removed(self) -> None:
+        """Test that account management endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/account/password")
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/account/password/email/requestToken"
+        )
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/account/password/msisdn/requestToken"
+        )
+
+    def test_registration_endpoints_removed(self) -> None:
+        """Test that registration endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized(
+            "GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
+        )
+        # This is still available for AS registrations
+        # self.expect_unrecognized("POST", "/_matrix/client/v3/register")
+        self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/register/email/requestToken"
+        )
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/register/msisdn/requestToken"
+        )
+
+    def test_session_management_endpoints_removed(self) -> None:
+        """Test that session management endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized("GET", "/_matrix/client/v3/login")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/login")
+        self.expect_unrecognized("GET", "/_matrix/client/v3/login/sso/redirect")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/logout")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/logout/all")
+        self.expect_unrecognized("POST", "/_matrix/client/v3/refresh")
+        self.expect_unrecognized("GET", "/_matrix/static/client/login")
+
+    def test_device_management_endpoints_removed(self) -> None:
+        """Test that device management endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
+        self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+    def test_openid_endpoints_removed(self) -> None:
+        """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized(
+            "POST", "/_matrix/client/v3/user/{USERNAME}/openid/request_token"
+        )
+
+    def test_admin_api_endpoints_removed(self) -> None:
+        """Test that admin API endpoints that were removed in MSC2964 are no longer available."""
+        self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens")
+        self.expect_unrecognized("POST", "/_synapse/admin/v1/registration_tokens/new")
+        self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens/abcd")
+        self.expect_unrecognized("PUT", "/_synapse/admin/v1/registration_tokens/abcd")
+        self.expect_unrecognized(
+            "DELETE", "/_synapse/admin/v1/registration_tokens/abcd"
+        )
+        self.expect_unrecognized("POST", "/_synapse/admin/v1/reset_password/foo")
+        self.expect_unrecognized("POST", "/_synapse/admin/v1/users/foo/login")
+        self.expect_unrecognized("GET", "/_synapse/admin/v1/register")
+        self.expect_unrecognized("POST", "/_synapse/admin/v1/register")
+        self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin")
+        self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin")
+        self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity")
+
+    def test_admin_token(self) -> None:
+        """The handler should return a requester with admin rights when admin_token is used."""
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(code=200, payload={"active": False}),
+        )
+
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"admin_token_value"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.assertEqual(
+            requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME)
+        )
+        self.assertEqual(requester.is_guest, False)
+        self.assertEqual(requester.device_id, None)
+        self.assertEqual(
+            get_awaitable_result(self.auth.is_server_admin(requester)), True
+        )
+
+        # There should be no call to the introspection endpoint
+        self.http_client.request.assert_not_called()
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0a8bae54fb..e797aaae00 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
@@ -28,7 +28,7 @@ from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 from synapse.util.stringutils import random_string
 
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
 from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -157,15 +157,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error  # type: ignore[assignment]
+        sso_handler.render_error = self.render_error  # type: ignore[method-assign]
 
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
         auth_handler = hs.get_auth_handler()
         # Mock the complete SSO login method.
-        self.complete_sso_login = simple_async_mock()
-        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+        self.complete_sso_login = AsyncMock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[method-assign]
 
         return hs
 
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index aa91bc0a3d..11ec8c7f11 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,7 +16,9 @@
 
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Type, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
+
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse
 from synapse.api.constants import LoginType
@@ -24,11 +26,12 @@ from synapse.api.errors import Codes
 from synapse.handlers.account import AccountHandler
 from synapse.module_api import ModuleApi
 from synapse.rest.client import account, devices, login, logout, register
+from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # Login flows we expect to appear in the list after the normal ones.
@@ -162,10 +165,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     CALLBACK_USERNAME = "get_username_for_registration"
     CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
 
-    def setUp(self) -> None:
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         # we use a global mock device, so make sure we are starting with a clean slate
         mock_password_provider.reset_mock()
-        super().setUp()
+
+        # The mock password provider doesn't register the users, so ensure they
+        # are registered first.
+        self.register_user("u", "not-the-tested-password")
+        self.register_user("user", "not-the-tested-password")
 
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
     def test_password_only_auth_progiver_login_legacy(self) -> None:
@@ -177,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -185,22 +194,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # login with mxid should work too
-        channel = self._send_password_login("@u:bz", "p")
+        channel = self._send_password_login("@u:test", "p")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
-        self.assertEqual("@u:bz", channel.json_body["user_id"])
-        mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+        self.assertEqual("@u:test", channel.json_body["user_id"])
+        mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
         mock_password_provider.reset_mock()
 
-        # try a weird username / pass. Honestly it's unclear what we *expect* to happen
-        # in these cases, but at least we can guard against the API changing
-        # unexpectedly
-        channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
-        self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
-        mock_password_provider.check_password.assert_called_once_with(
-            "@ USER🙂NAME :test", " pASS😢word "
-        )
-
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
     def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
         self.password_only_auth_provider_ui_auth_test_body()
@@ -208,18 +207,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     def password_only_auth_provider_ui_auth_test_body(self) -> None:
         """UI Auth should delegate correctly to the password provider"""
 
-        # create the user, otherwise access doesn't work
-        module_api = self.hs.get_module_api()
-        self.get_success(module_api.register_user("u"))
-
         # log in twice, to get two devices
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         tok1 = self.login("u", "p")
         self.login("u", "p", device_id="dev2")
         mock_password_provider.reset_mock()
 
         # have the auth provider deny the request to start with
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # make the initial request which returns a 401
         session = self._start_delete_device_session(tok1, "dev2")
@@ -233,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # Finally, check the request goes through when we allow it
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -247,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
 
@@ -264,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # have the auth provider deny the request
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "localpass")
@@ -307,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -329,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "p")
@@ -346,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_not_called()
 
         # now try deleting with the local password
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._authed_delete_device(
             tok1, "dev2", session, "localuser", "localpass"
         )
@@ -400,30 +395,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:bz", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
-        self.assertEqual("@user:bz", channel.json_body["user_id"])
+        self.assertEqual("@user:test", channel.json_body["user_id"])
         mock_password_provider.check_auth.assert_called_once_with(
             "u", "test.login_type", {"test_field": "y"}
         )
         mock_password_provider.reset_mock()
 
-        # try a weird username. Again, it's unclear what we *expect* to happen
-        # in these cases, but at least we can guard against the API changing
-        # unexpectedly
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@ MALFORMED! :bz", None)
-        )
-        channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
-        self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
-        mock_password_provider.check_auth.assert_called_once_with(
-            " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
-        )
-
     @override_config(legacy_providers_config(LegacyCustomAuthProvider))
     def test_custom_auth_provider_ui_auth_legacy(self) -> None:
         self.custom_auth_provider_ui_auth_test_body()
@@ -464,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:bz", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
@@ -477,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # and finally, succeed
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
@@ -495,14 +473,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.custom_auth_provider_callback_test_body()
 
     def custom_auth_provider_callback_test_body(self) -> None:
-        callback = Mock(return_value=make_awaitable(None))
+        callback = AsyncMock(return_value=None)
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:bz", callback)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@user:test", callback)
         )
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
-        self.assertEqual("@user:bz", channel.json_body["user_id"])
+        self.assertEqual("@user:test", channel.json_body["user_id"])
         mock_password_provider.check_auth.assert_called_once_with(
             "u", "test.login_type", {"test_field": "y"}
         )
@@ -512,7 +490,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         call_args, call_kwargs = callback.call_args
         # should be one positional arg
         self.assertEqual(len(call_args), 1)
-        self.assertEqual(call_args[0]["user_id"], "@user:bz")
+        self.assertEqual(call_args[0]["user_id"], "@user:test")
         for p in ["user_id", "access_token", "device_id", "home_server"]:
             self.assertIn(p, call_args[0])
 
@@ -633,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -852,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             username: The username to use for the test.
             registration: Whether to test with registration URLs.
         """
-        self.hs.get_identity_handler().send_threepid_validation = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(0),
+        self.hs.get_identity_handler().send_threepid_validation = AsyncMock(  # type: ignore[method-assign]
+            return_value=0
         )
 
-        m = Mock(return_value=make_awaitable(False))
+        m = AsyncMock(return_value=False)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         self.register_user(username, "password")
@@ -886,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "foo@test.com", registration)
 
-        m = Mock(return_value=make_awaitable(True))
+        m = AsyncMock(return_value=True)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 19f5322317..41c8c44e02 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -21,11 +21,12 @@ from signedjson.key import generate_signing_key
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership, PresenceState
-from synapse.api.presence import UserPresenceState
+from synapse.api.presence import UserDevicePresenceState, UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events.builder import EventBuilder
 from synapse.federation.sender import FederationSender
 from synapse.handlers.presence import (
+    BUSY_ONLINE_TIMEOUT,
     EXTERNAL_PROCESS_EXPIRY,
     FEDERATION_PING_INTERVAL,
     FEDERATION_TIMEOUT,
@@ -38,6 +39,7 @@ from synapse.handlers.presence import (
 from synapse.rest import admin
 from synapse.rest.client import room
 from synapse.server import HomeServer
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import Clock
 
@@ -351,6 +353,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
     def test_idle_timer(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -361,8 +364,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_user_sync_ts=now,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNotNone(new_state)
         assert new_state is not None
@@ -375,6 +391,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         presence state into unavailable.
         """
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -385,8 +402,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_user_sync_ts=now,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNotNone(new_state)
         assert new_state is not None
@@ -395,6 +425,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
     def test_sync_timeout(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -405,8 +436,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNotNone(new_state)
         assert new_state is not None
@@ -415,6 +459,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
     def test_sync_online(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -425,9 +470,20 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
         new_state = handle_timeout(
-            state, is_mine=True, syncing_user_ids={user_id}, now=now
+            state,
+            is_mine=True,
+            syncing_device_ids={(user_id, device_id)},
+            user_devices={device_id: device_state},
+            now=now,
         )
 
         self.assertIsNotNone(new_state)
@@ -437,6 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
     def test_federation_ping(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -448,14 +505,28 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNotNone(new_state)
         self.assertEqual(state, new_state)
 
     def test_no_timeout(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         now = 5000000
 
         state = UserPresenceState.default(user_id)
@@ -465,8 +536,21 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_user_sync_ts=now,
             last_federation_update_ts=now,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNone(new_state)
 
@@ -484,8 +568,9 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             status_msg=status_msg,
         )
 
+        # Note that this is a remote user so we do not have their device information.
         new_state = handle_timeout(
-            state, is_mine=False, syncing_user_ids=set(), now=now
+            state, is_mine=False, syncing_device_ids=set(), user_devices={}, now=now
         )
 
         self.assertIsNotNone(new_state)
@@ -495,6 +580,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
     def test_last_active(self) -> None:
         user_id = "@foo:bar"
+        device_id = "dev-1"
         status_msg = "I'm here!"
         now = 5000000
 
@@ -506,14 +592,150 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_federation_update_ts=now,
             status_msg=status_msg,
         )
+        device_state = UserDevicePresenceState(
+            user_id=user_id,
+            device_id=device_id,
+            state=state.state,
+            last_active_ts=state.last_active_ts,
+            last_sync_ts=state.last_user_sync_ts,
+        )
 
-        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+        new_state = handle_timeout(
+            state,
+            is_mine=True,
+            syncing_device_ids=set(),
+            user_devices={device_id: device_state},
+            now=now,
+        )
 
         self.assertIsNotNone(new_state)
         self.assertEqual(state, new_state)
 
 
+class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        # Disable background tasks on this worker so that the PresenceHandler isn't
+        # loaded until we request it.
+        config["run_background_tasks_on"] = "other"
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.user_id = f"@test:{self.hs.config.server.server_name}"
+        self.device_id = "dev-1"
+
+        # Move the reactor to the initial time.
+        self.reactor.advance(1000)
+        now = self.clock.time_msec()
+
+        main_store = hs.get_datastores().main
+        self.get_success(
+            main_store.update_presence(
+                [
+                    UserPresenceState(
+                        user_id=self.user_id,
+                        state=PresenceState.ONLINE,
+                        last_active_ts=now,
+                        last_federation_update_ts=now,
+                        last_user_sync_ts=now,
+                        status_msg=None,
+                        currently_active=True,
+                    )
+                ]
+            )
+        )
+
+        # Regenerate the preloaded presence information on PresenceStore.
+        def refill_presence(db_conn: LoggingDatabaseConnection) -> None:
+            main_store._presence_on_startup = main_store._get_active_presence(db_conn)
+
+        self.get_success(main_store.db_pool.runWithConnection(refill_presence))
+
+    def test_restored_presence_idles(self) -> None:
+        """The presence state restored from the database should not persist forever."""
+
+        # Get the handler (which kicks off a bunch of timers).
+        presence_handler = self.hs.get_presence_handler()
+
+        # Assert the user is online.
+        state = self.get_success(
+            presence_handler.get_state(UserID.from_string(self.user_id))
+        )
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
+        # Advance such that the user should timeout.
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000)
+        self.reactor.pump([5])
+
+        # Check that the user is now offline.
+        state = self.get_success(
+            presence_handler.get_state(UserID.from_string(self.user_id))
+        )
+        self.assertEqual(state.state, PresenceState.OFFLINE)
+
+    @parameterized.expand(
+        [
+            (PresenceState.BUSY, PresenceState.BUSY),
+            (PresenceState.ONLINE, PresenceState.ONLINE),
+            (PresenceState.UNAVAILABLE, PresenceState.ONLINE),
+            # Offline syncs don't update the state.
+            (PresenceState.OFFLINE, PresenceState.ONLINE),
+        ]
+    )
+    @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+    def test_restored_presence_online_after_sync(
+        self, sync_state: str, expected_state: str
+    ) -> None:
+        """
+        The presence state restored from the database should be overridden with sync after a timeout.
+
+        Args:
+            sync_state: The presence state of the new sync.
+            expected_state: The expected presence right after the sync.
+        """
+
+        # Get the handler (which kicks off a bunch of timers).
+        presence_handler = self.hs.get_presence_handler()
+
+        # Assert the user is online, as restored.
+        state = self.get_success(
+            presence_handler.get_state(UserID.from_string(self.user_id))
+        )
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
+        # Advance slightly and sync.
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2)
+        self.get_success(
+            presence_handler.user_syncing(
+                self.user_id,
+                self.device_id,
+                sync_state != PresenceState.OFFLINE,
+                sync_state,
+            )
+        )
+
+        # Assert the user is in the expected state.
+        state = self.get_success(
+            presence_handler.get_state(UserID.from_string(self.user_id))
+        )
+        self.assertEqual(state.state, expected_state)
+
+        # Advance such that the user's preloaded data times out, but not the new sync.
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2)
+        self.reactor.pump([5])
+
+        # Check that the user is in the sync state (as the client is currently syncing still).
+        state = self.get_success(
+            presence_handler.get_state(UserID.from_string(self.user_id))
+        )
+        self.assertEqual(state.state, sync_state)
+
+
 class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
+    user_id = "@test:server"
+    user_id_obj = UserID.from_string(user_id)
+    device_id = "dev-1"
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
@@ -522,62 +744,57 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         """Test that if an external process doesn't update the records for a while
         we time out their syncing users presence.
         """
-        process_id = "1"
-        user_id = "@test:server"
 
-        # Notify handler that a user is now syncing.
+        # Create a worker and use it to handle /sync traffic instead.
+        # This is used to test that presence changes get replicated from workers
+        # to the main process correctly.
+        worker_to_sync_against = self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+        )
+        worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
         self.get_success(
-            self.presence_handler.update_external_syncs_row(
-                process_id, user_id, True, self.clock.time_msec()
-            )
+            worker_presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )
 
         # Check that if we wait a while without telling the handler the user has
         # stopped syncing that their presence state doesn't get timed out.
         self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         self.assertEqual(state.state, PresenceState.ONLINE)
 
         # Check that if the external process timeout fires, then the syncing
         # user gets timed out
         self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         self.assertEqual(state.state, PresenceState.OFFLINE)
 
     def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
         """Test that if a user doesn't update the records for a while
         users presence goes `OFFLINE` because of timeout and `status_msg` remains.
         """
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
         # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
 
         # Check that if we wait a while without telling the handler the user has
         # stopped syncing that their presence state doesn't get timed out.
         self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2)
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         self.assertEqual(state.state, PresenceState.ONLINE)
         self.assertEqual(state.status_msg, status_msg)
 
         # Check that if the timeout fires, then the syncing user gets timed out
         self.reactor.advance(SYNC_ONLINE_TIMEOUT)
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         # status_msg should remain even after going offline
         self.assertEqual(state.state, PresenceState.OFFLINE)
         self.assertEqual(state.status_msg, status_msg)
@@ -586,24 +803,19 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         """Test that if a user change presence manually to `OFFLINE`
         and no status is set, that `status_msg` is `None`.
         """
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
         # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
 
         # Mark user as offline
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string(user_id), {"presence": PresenceState.OFFLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE}
             )
         )
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         self.assertEqual(state.state, PresenceState.OFFLINE)
         self.assertEqual(state.status_msg, None)
 
@@ -611,41 +823,31 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         """Test that if a user change presence manually to `OFFLINE`
         and a status is set, that `status_msg` appears.
         """
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
         # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
 
         # Mark user as offline
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.OFFLINE, "And now here."
-        )
+        self._set_presencestate_with_status_msg(PresenceState.OFFLINE, "And now here.")
 
     def test_user_reset_online_with_no_status(self) -> None:
         """Test that if a user set again the presence manually
         and no status is set, that `status_msg` is `None`.
         """
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
         # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
 
         # Mark user as online again
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string(user_id), {"presence": PresenceState.ONLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE}
             )
         )
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         # status_msg should remain even after going offline
         self.assertEqual(state.state, PresenceState.ONLINE)
         self.assertEqual(state.status_msg, None)
@@ -654,33 +856,27 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         """Test that if a user set again the presence manually
         and status is `None`, that `status_msg` is `None`.
         """
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
         # Mark user as online
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.ONLINE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
 
         # Mark user as online and `status_msg = None`
-        self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, None)
 
     def test_set_presence_from_syncing_not_set(self) -> None:
         """Test that presence is not set by syncing if affect_presence is false"""
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
-        )
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
 
         self.get_success(
-            self.presence_handler.user_syncing(user_id, False, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, False, PresenceState.ONLINE
+            )
         )
 
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         # we should still be unavailable
         self.assertEqual(state.state, PresenceState.UNAVAILABLE)
         # and status message should still be the same
@@ -688,50 +884,518 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
     def test_set_presence_from_syncing_is_set(self) -> None:
         """Test that presence is set by syncing if affect_presence is true"""
-        user_id = "@test:server"
         status_msg = "I'm here!"
 
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
+
+        self.get_success(
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
+        )
+
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        # we should now be online
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
+    @parameterized.expand(
+        # A list of tuples of 4 strings:
+        #
+        # * The presence state of device 1.
+        # * The presence state of device 2.
+        # * The expected user presence state after both devices have synced.
+        # * The expected user presence state after device 1 has idled.
+        # * The expected user presence state after device 2 has idled.
+        # * True to use workers, False a monolith.
+        [
+            (*cases, workers)
+            for workers in (False, True)
+            for cases in [
+                # If both devices have the same state, online should eventually idle.
+                # Otherwise, the state doesn't change.
+                (
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                ),
+                # If the second device has a "lower" state it should fallback to it,
+                # except for "busy" which overrides.
+                (
+                    PresenceState.BUSY,
+                    PresenceState.ONLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.BUSY,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.BUSY,
+                    PresenceState.OFFLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.OFFLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                # If the second device has a "higher" state it should override.
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+            ]
+        ],
+        name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+    )
+    @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+    def test_set_presence_from_syncing_multi_device(
+        self,
+        dev_1_state: str,
+        dev_2_state: str,
+        expected_state_1: str,
+        expected_state_2: str,
+        expected_state_3: str,
+        test_with_workers: bool,
+    ) -> None:
+        """
+        Test the behaviour of multiple devices syncing at the same time.
+
+        Roughly the user's presence state should be set to the "highest" priority
+        of all the devices. When a device then goes offline its state should be
+        discarded and the next highest should win.
+
+        Note that these tests use the idle timer (and don't close the syncs), it
+        is unlikely that a *single* sync would last this long, but is close enough
+        to continually syncing with that current state.
+        """
+        user_id = f"@test:{self.hs.config.server.server_name}"
+
+        # By default, we call /sync against the main process.
+        worker_presence_handler = self.presence_handler
+        if test_with_workers:
+            # Create a worker and use it to handle /sync traffic instead.
+            # This is used to test that presence changes get replicated from workers
+            # to the main process correctly.
+            worker_to_sync_against = self.make_worker_hs(
+                "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+            )
+            worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+        # 1. Sync with the first device.
+        self.get_success(
+            worker_presence_handler.user_syncing(
+                user_id,
+                "dev-1",
+                affect_presence=dev_1_state != PresenceState.OFFLINE,
+                presence_state=dev_1_state,
+            ),
+            by=0.01,
         )
 
+        # 2. Wait half the idle timer.
+        self.reactor.advance(IDLE_TIMER / 1000 / 2)
+        self.reactor.pump([0.1])
+
+        # 3. Sync with the second device.
         self.get_success(
-            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+            worker_presence_handler.user_syncing(
+                user_id,
+                "dev-2",
+                affect_presence=dev_2_state != PresenceState.OFFLINE,
+                presence_state=dev_2_state,
+            ),
+            by=0.01,
         )
 
+        # 4. Assert the expected presence state.
         state = self.get_success(
             self.presence_handler.get_state(UserID.from_string(user_id))
         )
-        # we should now be online
-        self.assertEqual(state.state, PresenceState.ONLINE)
+        self.assertEqual(state.state, expected_state_1)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, expected_state_1)
 
-    def test_set_presence_from_syncing_keeps_status(self) -> None:
-        """Test that presence set by syncing retains status message"""
-        user_id = "@test:server"
-        status_msg = "I'm here!"
+        # When testing with workers, make another random sync (with any *different*
+        # user) to keep the process information from expiring.
+        #
+        # This is due to EXTERNAL_PROCESS_EXPIRY being equivalent to IDLE_TIMER.
+        if test_with_workers:
+            with self.get_success(
+                worker_presence_handler.user_syncing(
+                    f"@other-user:{self.hs.config.server.server_name}",
+                    "dev-3",
+                    affect_presence=True,
+                    presence_state=PresenceState.ONLINE,
+                ),
+                by=0.01,
+            ):
+                pass
 
-        self._set_presencestate_with_status_msg(
-            user_id, PresenceState.UNAVAILABLE, status_msg
+        # 5. Advance such that the first device should be discarded (the idle timer),
+        # then pump so _handle_timeouts function to called.
+        self.reactor.advance(IDLE_TIMER / 1000 / 2)
+        self.reactor.pump([0.01])
+
+        # 6. Assert the expected presence state.
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
         )
+        self.assertEqual(state.state, expected_state_2)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, expected_state_2)
 
-        self.get_success(
-            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+        # 7. Advance such that the second device should be discarded (half the idle timer),
+        # then pump so _handle_timeouts function to called.
+        self.reactor.advance(IDLE_TIMER / 1000 / 2)
+        self.reactor.pump([0.1])
+
+        # 8. The devices are still "syncing" (the sync context managers were never
+        # closed), so might idle.
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
         )
+        self.assertEqual(state.state, expected_state_3)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, expected_state_3)
+
+    @parameterized.expand(
+        # A list of tuples of 4 strings:
+        #
+        # * The presence state of device 1.
+        # * The presence state of device 2.
+        # * The expected user presence state after both devices have synced.
+        # * The expected user presence state after device 1 has stopped syncing.
+        # * True to use workers, False a monolith.
+        [
+            (*cases, workers)
+            for workers in (False, True)
+            for cases in [
+                # If both devices have the same state, nothing exciting should happen.
+                (
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.OFFLINE,
+                ),
+                # If the second device has a "lower" state it should fallback to it,
+                # except for "busy" which overrides.
+                (
+                    PresenceState.BUSY,
+                    PresenceState.ONLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.BUSY,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.BUSY,
+                    PresenceState.OFFLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.ONLINE,
+                    PresenceState.UNAVAILABLE,
+                ),
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.OFFLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.OFFLINE,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.OFFLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.OFFLINE,
+                ),
+                # If the second device has a "higher" state it should override.
+                (
+                    PresenceState.ONLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                    PresenceState.BUSY,
+                ),
+                (
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                    PresenceState.ONLINE,
+                ),
+                (
+                    PresenceState.OFFLINE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                    PresenceState.UNAVAILABLE,
+                ),
+            ]
+        ],
+        name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+    )
+    @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
+    def test_set_presence_from_non_syncing_multi_device(
+        self,
+        dev_1_state: str,
+        dev_2_state: str,
+        expected_state_1: str,
+        expected_state_2: str,
+        test_with_workers: bool,
+    ) -> None:
+        """
+        Test the behaviour of multiple devices syncing at the same time.
+
+        Roughly the user's presence state should be set to the "highest" priority
+        of all the devices. When a device then goes offline its state should be
+        discarded and the next highest should win.
+
+        Note that these tests use the idle timer (and don't close the syncs), it
+        is unlikely that a *single* sync would last this long, but is close enough
+        to continually syncing with that current state.
+        """
+        user_id = f"@test:{self.hs.config.server.server_name}"
+
+        # By default, we call /sync against the main process.
+        worker_presence_handler = self.presence_handler
+        if test_with_workers:
+            # Create a worker and use it to handle /sync traffic instead.
+            # This is used to test that presence changes get replicated from workers
+            # to the main process correctly.
+            worker_to_sync_against = self.make_worker_hs(
+                "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+            )
+            worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
+        # 1. Sync with the first device.
+        sync_1 = self.get_success(
+            worker_presence_handler.user_syncing(
+                user_id,
+                "dev-1",
+                affect_presence=dev_1_state != PresenceState.OFFLINE,
+                presence_state=dev_1_state,
+            ),
+            by=0.1,
+        )
+
+        # 2. Sync with the second device.
+        sync_2 = self.get_success(
+            worker_presence_handler.user_syncing(
+                user_id,
+                "dev-2",
+                affect_presence=dev_2_state != PresenceState.OFFLINE,
+                presence_state=dev_2_state,
+            ),
+            by=0.1,
+        )
+
+        # 3. Assert the expected presence state.
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, expected_state_1)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, expected_state_1)
+
+        # 4. Disconnect the first device.
+        with sync_1:
+            pass
+
+        # 5. Advance such that the first device should be discarded (the sync timeout),
+        # then pump so _handle_timeouts function to called.
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000)
+        self.reactor.pump([5])
 
+        # 6. Assert the expected presence state.
         state = self.get_success(
             self.presence_handler.get_state(UserID.from_string(user_id))
         )
+        self.assertEqual(state.state, expected_state_2)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, expected_state_2)
+
+        # 7. Disconnect the second device.
+        with sync_2:
+            pass
+
+        # 8. Advance such that the second device should be discarded (the sync timeout),
+        # then pump so _handle_timeouts function to called.
+        if dev_1_state == PresenceState.BUSY or dev_2_state == PresenceState.BUSY:
+            timeout = BUSY_ONLINE_TIMEOUT
+        else:
+            timeout = SYNC_ONLINE_TIMEOUT
+        self.reactor.advance(timeout / 1000)
+        self.reactor.pump([5])
+
+        # 9. There are no more devices, should be offline.
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, PresenceState.OFFLINE)
+        if test_with_workers:
+            state = self.get_success(
+                worker_presence_handler.get_state(UserID.from_string(user_id))
+            )
+            self.assertEqual(state.state, PresenceState.OFFLINE)
+
+    def test_set_presence_from_syncing_keeps_status(self) -> None:
+        """Test that presence set by syncing retains status message"""
+        status_msg = "I'm here!"
+
+        self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
+
+        self.get_success(
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
+        )
+
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         # our status message should be the same as it was before
         self.assertEqual(state.status_msg, status_msg)
 
     @parameterized.expand([(False,), (True,)])
-    @unittest.override_config(
-        {
-            "experimental_features": {
-                "msc3026_enabled": True,
-            },
-        }
-    )
+    @unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
     def test_set_presence_from_syncing_keeps_busy(
         self, test_with_workers: bool
     ) -> None:
@@ -741,7 +1405,6 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
             test_with_workers: If True, check the presence state of the user by calling
                 /sync against a worker, rather than the main process.
         """
-        user_id = "@test:server"
         status_msg = "I'm busy!"
 
         # By default, we call /sync against the main process.
@@ -751,48 +1414,60 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
             # This is used to test that presence changes get replicated from workers
             # to the main process correctly.
             worker_to_sync_against = self.make_worker_hs(
-                "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+                "synapse.app.generic_worker", {"worker_name": "synchrotron"}
             )
 
         # Set presence to BUSY
-        self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+        self._set_presencestate_with_status_msg(PresenceState.BUSY, status_msg)
 
         # Perform a sync with a presence state other than busy. This should NOT change
         # our presence status; we only change from busy if we explicitly set it via
         # /presence/*.
         self.get_success(
             worker_to_sync_against.get_presence_handler().user_syncing(
-                user_id, True, PresenceState.ONLINE
-            )
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )
 
         # Check against the main process that the user's presence did not change.
-        state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         # we should still be busy
         self.assertEqual(state.state, PresenceState.BUSY)
 
+        # Advance such that the device would be discarded if it was not busy,
+        # then pump so _handle_timeouts function to called.
+        self.reactor.advance(IDLE_TIMER / 1000)
+        self.reactor.pump([5])
+
+        # The account should still be busy.
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        self.assertEqual(state.state, PresenceState.BUSY)
+
+        # Ensure that a /presence call can set the user *off* busy.
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
+
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
     def _set_presencestate_with_status_msg(
-        self, user_id: str, state: str, status_msg: Optional[str]
+        self, state: str, status_msg: Optional[str]
     ) -> None:
         """Set a PresenceState and status_msg and check the result.
 
         Args:
-            user_id: User for that the status is to be set.
             state: The new PresenceState.
             status_msg: Status message that is to be set.
         """
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string(user_id),
+                self.user_id_obj,
+                self.device_id,
                 {"presence": state, "status_msg": status_msg},
             )
         )
 
-        new_state = self.get_success(
-            self.presence_handler.get_state(UserID.from_string(user_id))
-        )
+        new_state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
         self.assertEqual(new_state.state, state)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -812,8 +1487,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -849,11 +1530,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         rows, upto_token, limited = self.get_success(
             self.queue.get_replication_rows("master", prev_token, now_token, 10)
@@ -892,8 +1579,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         self.reactor.advance(10 * 60 * 1000)
 
@@ -908,8 +1601,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -936,11 +1635,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
 
         self.reactor.advance(2 * 60 * 1000)
 
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         self.reactor.advance(4 * 60 * 1000)
 
@@ -952,15 +1657,18 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
         self.assertEqual(upto_token, now_token)
         self.assertFalse(limited)
 
-        expected_rows = [
-            (2, ("dest3", "@user3:test")),
-        ]
         self.assertCountEqual(rows, [])
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -993,7 +1701,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(
             "server",
-            federation_http_client=None,
             federation_sender=Mock(spec=FederationSender),
         )
         return hs
@@ -1033,7 +1740,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test2 as online, test will be offline with a last_active of 0
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
         self.reactor.pump([0])  # Wait for presence updates to be handled
@@ -1080,7 +1789,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test as online
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
 
@@ -1088,7 +1799,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Note we don't join them to the room yet
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
 
@@ -1145,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
 
         event = self.get_success(
-            builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+            builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
         )
 
         self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 64a9a22afe..f9b292b9ec 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
@@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class ProfileTestCase(unittest.HomeserverTestCase):
@@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
     servlets = [admin.register_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.mock_federation = Mock()
+        self.mock_federation = AsyncMock()
         self.mock_registry = Mock()
 
         self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -80,11 +79,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (
-                self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_displayname(self.frank))),
             "Frank Jr.",
         )
 
@@ -96,11 +91,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (
-                self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_displayname(self.frank))),
             "Frank",
         )
 
@@ -112,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_displayname(self.frank.localpart))
+            self.get_success(self.store.get_profile_displayname(self.frank))
         )
 
     def test_set_my_name_if_disabled(self) -> None:
@@ -122,11 +113,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
 
         self.assertEqual(
-            (
-                self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_displayname(self.frank))),
             "Frank",
         )
 
@@ -147,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_other_name(self) -> None:
-        self.mock_federation.make_query.return_value = make_awaitable(
-            {"displayname": "Alice"}
-        )
+        self.mock_federation.make_query.return_value = {"displayname": "Alice"}
 
         displayname = self.get_success(self.handler.get_displayname(self.alice))
 
@@ -191,6 +176,16 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual("http://my.server/me.png", avatar_url)
 
+    def test_get_profile_empty_displayname(self) -> None:
+        self.get_success(self.store.set_profile_displayname(self.frank, None))
+        self.get_success(
+            self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png")
+        )
+
+        profile = self.get_success(self.handler.get_profile(self.frank.to_string()))
+
+        self.assertEqual("http://my.server/me.png", profile["avatar_url"])
+
     def test_set_my_avatar(self) -> None:
         self.get_success(
             self.handler.set_avatar_url(
@@ -201,7 +196,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank))),
             "http://my.server/pic.gif",
         )
 
@@ -215,7 +210,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank))),
             "http://my.server/me.png",
         )
 
@@ -229,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank))),
         )
 
     def test_set_my_avatar_if_disabled(self) -> None:
@@ -241,7 +236,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank))),
             "http://my.server/me.png",
         )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 73822b07a5..e9fbf32c7c 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 from typing import Any, Collection, List, Optional, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
     CodeMessageException,
@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self) -> None:
-        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
+        self.store.count_monthly_users = AsyncMock(  # type: ignore[method-assign]
+            return_value=self.hs.config.server.max_mau_value - 1
         )
         # Ensure does not throw exception
         self.get_success(self.get_or_create_user(self.requester, "c", "User"))
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_register_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     @override_config({"auto_join_rooms": ["#room:test"]})
     def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
         room_alias_str = "#room:test"
-        self.store.is_real_user = Mock(return_value=make_awaitable(False))
+        self.store.is_real_user = AsyncMock(return_value=False)
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=make_awaitable(1))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=1)  # type: ignore[method-assign]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_directory_handler()
@@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
         self,
     ) -> None:
-        self.store.count_real_users = Mock(return_value=make_awaitable(2))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=2)  # type: ignore[method-assign]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -587,17 +582,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.assertFalse(self.get_success(d))
 
     def test_invalid_user_id(self) -> None:
-        invalid_user_id = "+abcd"
+        invalid_user_id = "^abcd"
         self.get_failure(
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
-    @override_config({"experimental_features": {"msc4009_e164_mxids": True}})
-    def text_extended_user_ids(self) -> None:
-        """+ should be allowed according to MSC4009."""
-        valid_user_id = "+1234"
+    def test_special_chars(self) -> None:
+        """Ensure that characters which are allowed in Matrix IDs work."""
+        valid_user_id = "a1234_-./=+"
         user_id = self.get_success(self.handler.register_user(localpart=valid_user_id))
-        self.assertEqual(user_id, valid_user_id)
+        self.assertEqual(user_id, f"@{valid_user_id}:test")
 
     def test_invalid_user_id_length(self) -> None:
         invalid_user_id = "x" * 256
@@ -683,7 +677,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         request = Mock(args={})
         request.args[b"access_token"] = [token.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        auth = Auth(self.hs)
+        auth = InternalAuth(self.hs)
         requester = self.get_success(auth.get_user_by_req(request))
 
         self.assertTrue(requester.shadow_banned)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index a444d822cd..3e28117e2c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -16,7 +16,6 @@ from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
-from tests.test_utils import make_awaitable
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             None,
         )
 
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    self.OTHER_SERVER_NAME,
-                    join_event,
-                    self.hs.config.server.default_room_version,
-                )
+        mock_make_membership_event = AsyncMock(
+            return_value=(
+                self.OTHER_SERVER_NAME,
+                join_event,
+                self.hs.config.server.default_room_version,
             )
         )
-        mock_send_join = Mock(
-            return_value=make_awaitable(
-                SendJoinResult(
-                    join_event,
-                    self.OTHER_SERVER_NAME,
-                    state=[create_event],
-                    auth_chain=[create_event],
-                    partial_state=False,
-                    servers_in_room=frozenset(),
-                )
+        mock_send_join = AsyncMock(
+            return_value=SendJoinResult(
+                join_event,
+                self.OTHER_SERVER_NAME,
+                state=[create_event],
+                auth_chain=[create_event],
+                partial_state=False,
+                servers_in_room=frozenset(),
             )
         )
 
@@ -333,6 +328,27 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
             self.get_success(self.store.is_locally_forgotten_room(self.room_id))
         )
 
+    def test_leave_and_unforget(self) -> None:
+        """Tests if rejoining a room unforgets the room, so that it shows up in sync again."""
+        self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+        # alice is not the last room member that leaves and forgets the room
+        self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+        self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+        self.assertTrue(
+            self.get_success(self.store.did_forget(self.alice, self.room_id))
+        )
+
+        self.helper.join(self.room_id, user=self.alice, tok=self.alice_token)
+        self.assertFalse(
+            self.get_success(self.store.did_forget(self.alice, self.room_id))
+        )
+
+        # the server has not forgotten the room
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+        )
+
     @override_config({"forget_rooms_on_leave": True})
     def test_leave_and_auto_forget(self) -> None:
         """Tests the `forget_rooms_on_leave` config option."""
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index b5c772a7ae..00f4e181e8 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -13,7 +13,7 @@
 #  limitations under the License.
 
 from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # Check if we have the dependencies to run the tests.
@@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
@@ -206,11 +205,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # mock out the error renderer too
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[method-assign]
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
         request = _mock_request()
@@ -227,9 +226,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler and error renderer
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[method-assign]
 
         # register a user to occupy the first-choice MXID
         store = self.hs.get_datastores().main
@@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # The response doesn't have the proper userGroup or department.
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
 # limitations under the License.
 
 
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
 
 from zope.interface import implementer
 
 from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
 from twisted.mail import interfaces, smtp
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase, override_config
 
 
+def TestingESMTPTLSClientFactory(
+    contextFactory: ContextFactory,
+    _connectWrapped: bool,
+    wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+    """We use this to pass through in testing without using TLS, but
+    saving the context information to check that it would have happened.
+
+    Note that this is what the MemoryReactor does on connectSSL.
+    It only saves the contextFactory, but starts the connection with the
+    underlying Factory.
+    See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+    wrappedProtocol._testingContextFactory = contextFactory  # type: ignore[attr-defined]
+    return wrappedProtocol
+
+
 @implementer(interfaces.IMessageDelivery)
 class _DummyMessageDelivery:
     def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
         pass
 
 
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+    ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
     def test_send_email(self) -> None:
         """Happy-path test that we can send email to a non-TLS server."""
         h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
             0
         ]
-        self.assertEqual(host, "localhost")
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 25)
 
         # wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
 
+    @patch(
+        "synapse.handlers.send_email.TLSMemoryBIOFactory",
+        TestingESMTPTLSClientFactory,
+    )
     @override_config(
         {
             "email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             )
         )
         # there should be an attempt to connect to localhost:465
-        self.assertEqual(len(self.reactor.sslClients), 1)
+        self.assertEqual(len(self.reactor.tcpClients), 1)
         (
             host,
             port,
             client_factory,
-            contextFactory,
             _timeout,
             _bindAddress,
-        ) = self.reactor.sslClients[0]
-        self.assertEqual(host, "localhost")
+        ) = self.reactor.tcpClients[0]
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 465)
+        # We need to make sure that TLS is happenning
+        self.assertIsInstance(
+            client_factory._wrappedFactory._testingContextFactory,
+            ClientTLSOptions,
+        )
+        # And since we use endpoints, they go through reactor.connectTCP
+        # which works differently to connectSSL on the testing reactor
 
         # wire it up to an SMTP server
         message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         user, msg = message_delivery.messages.pop()
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+    ip_class = IPv6Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 0d9a3de92a..948d04fc32 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Optional
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -29,7 +29,6 @@ from synapse.util import Clock
 
 import tests.unittest
 import tests.utils
-from tests.test_utils import make_awaitable
 
 
 class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -163,7 +162,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Blow away caches (supported room versions can only change due to a restart).
         self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
         self.store.get_rooms_for_user.invalidate_all()
-        self.get_success(self.store._get_event_cache.clear())
+        self.store._get_event_cache.clear()
         self.store._event_ref.clear()
 
         # The rooms should be excluded from the sync response.
@@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         mocked_get_prev_events = patch.object(
             self.hs.get_datastores().main,
             "get_prev_events_for_room",
-            new_callable=MagicMock,
-            return_value=make_awaitable([last_room_creation_event_id]),
+            new_callable=AsyncMock,
+            return_value=[last_room_creation_event_id],
         )
         with mocked_get_prev_events:
             self.helper.join(room_id, eve, tok=eve_token)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 94518a7196..3060bc9744 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,7 +15,9 @@
 
 import json
 from typing import Dict, List, Set
-from unittest.mock import ANY, Mock, call
+from unittest.mock import ANY, AsyncMock, Mock, call
+
+from netaddr import IPSet
 
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
@@ -24,13 +26,13 @@ from synapse.api.constants import EduTypes
 from synapse.api.errors import AuthError
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.handlers.typing import TypingWriterHandler
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.server import HomeServer
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
 from tests.server import ThreadedMemoryReactorClock
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # Some local users to test with
@@ -71,11 +73,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
-        mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
+        mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
 
         # we mock out the federation client too
-        self.mock_federation_client = Mock(spec=["put_json"])
-        self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+        self.mock_federation_client = AsyncMock(spec=["put_json"])
+        self.mock_federation_client.put_json.return_value = (200, "OK")
+        self.mock_federation_client.agent = MatrixFederationAgent(
+            reactor,
+            tls_client_options_factory=None,
+            user_agent=b"SynapseInTrialTest/0.0.0",
+            ip_allowlist=None,
+            ip_blocklist=IPSet(),
+        )
 
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
@@ -111,20 +120,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore = hs.get_datastores().main
 
-        self.datastore.get_destination_retry_timings = Mock(
-            return_value=make_awaitable(None)
-        )
-
-        self.datastore.get_device_updates_by_remote = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable((0, []))
+        self.datastore.get_device_updates_by_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=(0, [])
         )
 
-        self.datastore.get_destination_last_successful_stream_ordering = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self.datastore.get_destination_last_successful_stream_ordering = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
-        self.datastore.get_received_txn_response = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self.datastore.get_received_txn_response = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
         self.room_members: List[UserID] = []
@@ -136,25 +141,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 raise AuthError(401, "User is not in the room")
             return None
 
-        hs.get_auth().check_user_in_room = Mock(  # type: ignore[assignment]
+        hs.get_auth().check_user_in_room = Mock(  # type: ignore[method-assign]
             side_effect=check_user_in_room
         )
 
         async def check_host_in_room(room_id: str, server_name: str) -> bool:
             return room_id == ROOM_ID
 
-        hs.get_event_auth_handler().is_host_in_room = Mock(  # type: ignore[assignment]
+        hs.get_event_auth_handler().is_host_in_room = Mock(  # type: ignore[method-assign]
             side_effect=check_host_in_room
         )
 
         async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             return {member.domain for member in self.room_members}
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[method-assign]
             side_effect=get_current_hosts_in_room
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock(  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock(  # type: ignore[method-assign]
             side_effect=get_current_hosts_in_room
         )
 
@@ -163,27 +168,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
 
-        self.datastore.get_user_directory_stream_pos = Mock(  # type: ignore[assignment]
-            side_effect=(
-                # we deliberately return a non-None stream pos to avoid
-                # doing an initial_sync
-                lambda: make_awaitable(1)
-            )
+        self.datastore.get_user_directory_stream_pos = AsyncMock(  # type: ignore[method-assign]
+            # we deliberately return a non-None stream pos to avoid
+            # doing an initial_sync
+            return_value=1
         )
 
-        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[assignment]
+        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[method-assign]
 
-        self.datastore.get_to_device_stream_token = Mock(  # type: ignore[assignment]
-            side_effect=lambda: 0
+        self.datastore.get_to_device_stream_token = Mock(  # type: ignore[method-assign]
+            return_value=0
         )
-        self.datastore.get_new_device_msgs_for_remote = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kargs: make_awaitable(([], 0))
+        self.datastore.get_new_device_msgs_for_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=([], 0)
         )
-        self.datastore.delete_device_msgs_for_remote = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kargs: make_awaitable(None)
+        self.datastore.delete_device_msgs_for_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
-        self.datastore.set_received_txn_response = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kwargs: make_awaitable(None)
+        self.datastore.set_received_txn_response = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
     def test_started_typing_local(self) -> None:
@@ -200,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+        )
 
         self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
@@ -246,8 +251,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ),
             json_data_callback=ANY,
             long_retries=True,
-            backoff_on_404=True,
             try_trailing_slash_on_400=True,
+            backoff_on_all_error_codes=True,
         )
 
     def test_started_typing_remote_recv(self) -> None:
@@ -270,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 200)
 
-        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+        )
 
         self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
@@ -346,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+        )
 
         self.mock_federation_client.put_json.assert_called_once_with(
             "farm",
@@ -361,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ),
             json_data_callback=ANY,
             long_retries=True,
-            backoff_on_404=True,
+            backoff_on_all_error_codes=True,
             try_trailing_slash_on_400=True,
         )
 
@@ -396,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+        )
         self.on_new_event.reset_mock()
 
         self.assertEqual(self.event_source.get_current_key(), 1)
@@ -422,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.reactor.pump([16])
 
-        self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
+        )
 
         self.assertEqual(self.event_source.get_current_key(), 2)
         events = self.get_success(
@@ -456,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
+        self.on_new_event.assert_has_calls(
+            [call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
+        )
         self.on_new_event.reset_mock()
 
         self.assertEqual(self.event_source.get_current_key(), 3)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 15a7dc6818..b5f15aa7d4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Tuple
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 from urllib.parse import quote
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -30,7 +30,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.storage.test_user_directory import GetUserDirectoryTables
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import override_config
 
@@ -356,7 +356,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
                 support_user_id, ProfileInfo("I love support me", None)
             )
         )
-        profile = self.get_success(self.store.get_user_in_directory(support_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(support_user_id))
         self.assertIsNone(profile)
         display_name = "display_name"
 
@@ -364,7 +364,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
-        profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
         assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
@@ -383,7 +383,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # profile is in directory
-        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(r_user_id))
         assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
@@ -392,7 +392,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
 
         # profile is not in directory
-        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(r_user_id))
         self.assertIsNone(profile)
 
         # update profile after deactivation
@@ -401,7 +401,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # profile is furthermore not in directory
-        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(r_user_id))
         self.assertIsNone(profile)
 
     def test_handle_local_profile_change_with_appservice_user(self) -> None:
@@ -411,7 +411,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # profile is not in directory
-        profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(as_user_id))
         self.assertIsNone(profile)
 
         # update profile
@@ -421,13 +421,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # profile is still not in directory
-        profile = self.get_success(self.store.get_user_in_directory(as_user_id))
+        profile = self.get_success(self.store._get_user_in_directory(as_user_id))
         self.assertIsNone(profile)
 
     def test_handle_local_profile_change_with_appservice_sender(self) -> None:
         # profile is not in directory
         profile = self.get_success(
-            self.store.get_user_in_directory(self.appservice.sender)
+            self.store._get_user_in_directory(self.appservice.sender)
         )
         self.assertIsNone(profile)
 
@@ -441,11 +441,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # profile is still not in directory
         profile = self.get_success(
-            self.store.get_user_in_directory(self.appservice.sender)
+            self.store._get_user_in_directory(self.appservice.sender)
         )
         self.assertIsNone(profile)
 
     def test_handle_user_deactivated_support_user(self) -> None:
+        """Ensure a support user doesn't get added to the user directory after deactivation."""
         s_user_id = "@support:test"
         self.get_success(
             self.store.register_user(
@@ -453,14 +454,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
-        with patch.object(
-            self.store, "remove_from_user_dir", mock_remove_from_user_dir
-        ):
-            self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
-        # BUG: the correct spelling is assert_not_called, but that makes the test fail
-        # and it's not clear that this is actually the behaviour we want.
-        mock_remove_from_user_dir.not_called()
+        # The profile should not be in the directory.
+        profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+        self.assertIsNone(profile)
+
+        # Remove the user from the directory.
+        self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+
+        # The profile should still not be in the user directory.
+        profile = self.get_success(self.store._get_user_in_directory(s_user_id))
+        self.assertIsNone(profile)
 
     def test_handle_user_deactivated_regular_user(self) -> None:
         r_user_id = "@regular:test"
@@ -468,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.store.register_user(user_id=r_user_id, password_hash=None)
         )
 
-        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
+        mock_remove_from_user_dir = AsyncMock(return_value=None)
         with patch.object(
             self.store, "remove_from_user_dir", mock_remove_from_user_dir
         ):
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
new file mode 100644
index 0000000000..73e548726c
--- /dev/null
+++ b/tests/handlers/test_worker_lock.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class WorkerLockTestCase(unittest.HomeserverTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.worker_lock_handler = self.hs.get_worker_locks_handler()
+
+    def test_wait_for_lock_locally(self) -> None:
+        """Test waiting for a lock on a single worker"""
+
+        lock1 = self.worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        lock2 = self.worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        self.assertNoResult(d2)
+
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        self.get_success(d2)
+        self.get_success(lock2.__aexit__(None, None, None))
+
+
+class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.main_worker_lock_handler = self.hs.get_worker_locks_handler()
+
+    def test_wait_for_lock_worker(self) -> None:
+        """Test waiting for a lock on another worker"""
+
+        worker = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={
+                "redis": {"enabled": True},
+            },
+        )
+        worker_lock_handler = worker.get_worker_locks_handler()
+
+        lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        lock2 = worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        self.assertNoResult(d2)
+
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        self.get_success(d2)
+        self.get_success(lock2.__aexit__(None, None, None))
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 105b4caefa..9f63fa6fa8 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -14,8 +14,8 @@
 import base64
 import logging
 import os
-from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
-from unittest.mock import Mock, patch
+from typing import Generator, List, Optional, cast
+from unittest.mock import AsyncMock, call, patch
 
 import treq
 from netaddr import IPSet
@@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
-from synapse.http.federation.srv_resolver import Server
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import (
     WELL_KNOWN_MAX_SIZE,
     WellKnownResolver,
@@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config
 logger = logging.getLogger(__name__)
 
 
-# Once Async Mocks or lambdas are supported this can go away.
-def generate_resolve_service(
-    result: List[Server],
-) -> Callable[[Any], Awaitable[List[Server]]]:
-    async def resolve_service(_: Any) -> List[Server]:
-        return result
-
-    return resolve_service
-
-
 class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
-        self.mock_resolver = Mock()
+        self.mock_resolver = AsyncMock(spec=SrvResolver)
 
         config_dict = default_config("test", parse=False)
         config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@@ -292,7 +282,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.agent = self._make_agent()
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
-        test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv:8448/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -393,7 +383,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["proxy.com"] = "9.9.9.9"
-        test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv:8448/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -514,7 +504,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(response.code, 200)
 
         # Send the body
-        request.write('{ "a": 1 }'.encode("ascii"))
+        request.write(b'{ "a": 1 }')
         request.finish()
 
         self.reactor.pump((0.1,))
@@ -532,7 +522,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # there will be a getaddrinfo on the IP
         self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://1.2.3.4/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -568,7 +558,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # there will be a getaddrinfo on the IP
         self.reactor.lookups["::1"] = "::1"
 
-        test_d = self._make_get_request(b"matrix://[::1]/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://[::1]/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -604,7 +594,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # there will be a getaddrinfo on the IP
         self.reactor.lookups["::1"] = "::1"
 
-        test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://[::1]:80/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -636,10 +626,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv1"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -661,9 +651,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # .well-known request fails.
         self.reactor.pump((0.4,))
 
-        # now there should be a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv1"
+        # now there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv1"), call(b"_matrix._tcp.testserv1")]
         )
 
         # we should fall back to a direct connection
@@ -693,7 +683,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # there will be a getaddrinfo on the IP
         self.reactor.lookups["1.2.3.5"] = "1.2.3.5"
 
-        test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://1.2.3.5/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -722,10 +712,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -747,9 +737,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # .well-known request fails.
         self.reactor.pump((0.4,))
 
-        # now there should be a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv"
+        # now there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
         )
 
         # we should fall back to a direct connection
@@ -776,11 +766,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the .well-known delegates elsewhere"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -798,9 +788,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
             content=b'{ "m.server": "target-server" }',
         )
 
-        # there should be a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.target-server"
+        # there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [
+                call(b"_matrix-fed._tcp.target-server"),
+                call(b"_matrix._tcp.target-server"),
+            ]
         )
 
         # now we should get a connection to the target server
@@ -840,11 +833,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -888,9 +881,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         self.reactor.pump((0.1,))
 
-        # there should be a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.target-server"
+        # there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [
+                call(b"_matrix-fed._tcp.target-server"),
+                call(b"_matrix._tcp.target-server"),
+            ]
         )
 
         # now we should get a connection to the target server
@@ -930,10 +926,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -952,9 +948,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
             client_factory, expected_sni=b"testserv", content=b"NOT JSON"
         )
 
-        # now there should be a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv"
+        # now there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
         )
 
         # we should fall back to a direct connection
@@ -986,7 +982,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # the config left to the default, which will not trust it (since the
         # presented cert is signed by a test CA)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         config = default_config("test", parse=True)
@@ -1009,7 +1005,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             ),
         )
 
-        test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
+        test_d = agent.request(b"GET", b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -1026,30 +1022,74 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # there should be no requests
         self.assertEqual(len(http_proto.requests), 0)
 
-        # and there should be a SRV lookup instead
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv"
+        # and there should be two SRV lookups instead
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
         )
 
     def test_get_hostname_srv(self) -> None:
         """
-        Test the behaviour when there is a single SRV record
+        Test the behaviour when there is a single SRV record for _matrix-fed.
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
         self.reactor.lookups["srvtarget"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
 
         # the request for a .well-known will have failed with a DNS lookup error.
         self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv"
+            b"_matrix-fed._tcp.testserv"
+        )
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_get_hostname_srv_legacy(self) -> None:
+        """
+        Test the behaviour when there is a single SRV record for _matrix.
+        """
+        self.agent = self._make_agent()
+
+        # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+        self.mock_resolver.resolve_service.side_effect = [
+            [],
+            [Server(host=b"srvtarget", port=8443)],
+        ]
+        self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # the request for a .well-known will have failed with a DNS lookup error.
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
         )
 
         # Make sure treq is trying to connect
@@ -1075,14 +1115,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
     def test_get_well_known_srv(self) -> None:
         """Test the behaviour when the .well-known redirects to a place where there
-        is a SRV.
+        is a _matrix-fed SRV record.
         """
         self.agent = self._make_agent()
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["srvtarget"] = "5.6.7.8"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -1094,9 +1134,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
 
         self._handle_well_known_connection(
             client_factory,
@@ -1106,7 +1146,72 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         # there should be a SRV lookup
         self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.target-server"
+            b"_matrix-fed._tcp.target-server"
+        )
+
+        # now we should get a connection to the target of the SRV record
+        self.assertEqual(len(clients), 2)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+        self.assertEqual(host, "5.6.7.8")
+        self.assertEqual(port, 8443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, expected_sni=b"target-server"
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(
+            request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]
+        )
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_get_well_known_srv_legacy(self) -> None:
+        """Test the behaviour when the .well-known redirects to a place where there
+        is a _matrix SRV record.
+        """
+        self.agent = self._make_agent()
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+        self.reactor.lookups["srvtarget"] = "5.6.7.8"
+
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+        self.mock_resolver.resolve_service.side_effect = [
+            [],
+            [Server(host=b"srvtarget", port=8443)],
+        ]
+
+        self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            content=b'{ "m.server": "target-server" }',
+        )
+
+        # there should be two SRV lookups
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [
+                call(b"_matrix-fed._tcp.target-server"),
+                call(b"_matrix._tcp.target-server"),
+            ]
         )
 
         # now we should get a connection to the target of the SRV record
@@ -1137,13 +1242,15 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """test the behaviour when the server name has idna chars in"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
 
         # the resolver is always called with the IDNA hostname as a native string.
         self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
 
         # this is idna for bücher.com
-        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
+        test_d = self._make_get_request(
+            b"matrix-federation://xn--bcher-kva.com/foo/bar"
+        )
 
         # Nothing happened yet
         self.assertNoResult(test_d)
@@ -1166,8 +1273,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.4,))
 
         # now there should have been a SRV lookup
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.xn--bcher-kva.com"
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [
+                call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+                call(b"_matrix._tcp.xn--bcher-kva.com"),
+            ]
         )
 
         # We should fall back to port 8448
@@ -1196,21 +1306,73 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.successResultOf(test_d)
 
     def test_idna_srv_target(self) -> None:
-        """test the behaviour when the target of a SRV record has idna chars"""
+        """test the behaviour when the target of a _matrix-fed SRV record has idna chars"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"xn--trget-3qa.com", port=8443)]  # târget.com
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"xn--trget-3qa.com", port=8443)
+        ]  # târget.com
         self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
+        test_d = self._make_get_request(
+            b"matrix-federation://xn--bcher-kva.com/foo/bar"
+        )
 
         # Nothing happened yet
         self.assertNoResult(test_d)
 
         self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.xn--bcher-kva.com"
+            b"_matrix-fed._tcp.xn--bcher-kva.com"
+        )
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, expected_sni=b"xn--bcher-kva.com"
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(
+            request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]
+        )
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_idna_srv_target_legacy(self) -> None:
+        """test the behaviour when the target of a _matrix SRV record has idna chars"""
+        self.agent = self._make_agent()
+
+        # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+        self.mock_resolver.resolve_service.side_effect = [
+            [],
+            [Server(host=b"xn--trget-3qa.com", port=8443)],
+        ]  # târget.com
+        self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
+
+        test_d = self._make_get_request(
+            b"matrix-federation://xn--bcher-kva.com/foo/bar"
+        )
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [
+                call(b"_matrix-fed._tcp.xn--bcher-kva.com"),
+                call(b"_matrix._tcp.xn--bcher-kva.com"),
+            ]
         )
 
         # Make sure treq is trying to connect
@@ -1400,24 +1562,82 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertIsNone(r.delegated_server)
 
     def test_srv_fallbacks(self) -> None:
-        """Test that other SRV results are tried if the first one fails."""
+        """Test that other SRV results are tried if the first one fails for _matrix-fed SRV."""
+        self.agent = self._make_agent()
+
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"target.com", port=8443),
+            Server(host=b"target.com", port=8444),
+        ]
+        self.reactor.lookups["target.com"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix-fed._tcp.testserv"
+        )
+
+        # We should see an attempt to connect to the first server
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8443)
+
+        # Fonx the connection
+        client_factory.clientConnectionFailed(None, Exception("nope"))
+
+        # There's a 300ms delay in HostnameEndpoint
+        self.reactor.pump((0.4,))
+
+        # Hasn't failed yet
+        self.assertNoResult(test_d)
+
+        # We shouldnow see an attempt to connect to the second server
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8444)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
+    def test_srv_fallbacks_legacy(self) -> None:
+        """Test that other SRV results are tried if the first one fails for _matrix SRV."""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+        # Return no entries for the _matrix-fed lookup, and a response for _matrix.
+        self.mock_resolver.resolve_service.side_effect = [
+            [],
             [
                 Server(host=b"target.com", port=8443),
                 Server(host=b"target.com", port=8444),
-            ]
-        )
+            ],
+        ]
         self.reactor.lookups["target.com"] = "1.2.3.4"
 
-        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
 
         # Nothing happened yet
         self.assertNoResult(test_d)
 
-        self.mock_resolver.resolve_service.assert_called_once_with(
-            b"_matrix._tcp.testserv"
+        self.mock_resolver.resolve_service.assert_has_calls(
+            [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")]
         )
 
         # We should see an attempt to connect to the first server
@@ -1457,6 +1677,43 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
+    def test_srv_no_fallback_to_legacy(self) -> None:
+        """Test that _matrix SRV results are not tried if the _matrix-fed one fails."""
+        self.agent = self._make_agent()
+
+        # Return a failing entry for _matrix-fed.
+        self.mock_resolver.resolve_service.side_effect = [
+            [Server(host=b"target.com", port=8443)],
+            [],
+        ]
+        self.reactor.lookups["target.com"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # Only the _matrix-fed is checked, _matrix is ignored.
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix-fed._tcp.testserv"
+        )
+
+        # We should see an attempt to connect to the first server
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8443)
+
+        # Fonx the connection
+        client_factory.clientConnectionFailed(None, Exception("nope"))
+
+        # There's a 300ms delay in HostnameEndpoint
+        self.reactor.pump((0.4,))
+
+        # Failed to resolve a server.
+        self.assertFailure(test_d, Exception)
+
 
 class TestCachePeriodFromHeaders(unittest.TestCase):
     def test_cache_control(self) -> None:
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index 0dfc03ce50..ab94f3f67a 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Generator
-from unittest.mock import Mock
+from typing import Any, Dict, Generator
+from unittest.mock import ANY, Mock, create_autospec
 
 from netaddr import IPSet
 from parameterized import parameterized
@@ -21,10 +21,12 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred, TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
 from twisted.test.proto_helpers import MemoryReactor, StringTransport
-from twisted.web.client import ResponseNeverReceived
+from twisted.web.client import Agent, ResponseNeverReceived
 from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
 
-from synapse.api.errors import RequestSendFailed
+from synapse.api.errors import HttpResponseException, RequestSendFailed
+from synapse.config._base import ConfigError
 from synapse.http.matrixfederationclient import (
     ByteParser,
     MatrixFederationHttpClient,
@@ -39,8 +41,10 @@ from synapse.logging.context import (
 from synapse.server import HomeServer
 from synapse.util import Clock
 
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeTransport
-from tests.unittest import HomeserverTestCase
+from tests.test_utils import FakeResponse
+from tests.unittest import HomeserverTestCase, override_config
 
 
 def check_logcontext(context: LoggingContextOrSentinel) -> None:
@@ -640,3 +644,293 @@ class FederationClientTests(HomeserverTestCase):
             self.cl.build_auth_headers(
                 b"", b"GET", b"https://example.com", destination_is=b""
             )
+
+    @override_config(
+        {
+            "federation": {
+                "client_timeout": "180s",
+                "max_long_retry_delay": "100s",
+                "max_short_retry_delay": "7s",
+                "max_long_retries": 20,
+                "max_short_retries": 5,
+            }
+        }
+    )
+    def test_configurable_retry_and_delay_values(self) -> None:
+        self.assertEqual(self.cl.default_timeout_seconds, 180)
+        self.assertEqual(self.cl.max_long_retry_delay_seconds, 100)
+        self.assertEqual(self.cl.max_short_retry_delay_seconds, 7)
+        self.assertEqual(self.cl.max_long_retries, 20)
+        self.assertEqual(self.cl.max_short_retries, 5)
+
+
+class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
+    def default_config(self) -> Dict[str, Any]:
+        conf = super().default_config()
+        conf["instance_map"] = {
+            "main": {"host": "testserv", "port": 8765},
+            "federation_sender": {"host": "testserv", "port": 1001},
+        }
+        return conf
+
+    @override_config(
+        {
+            "outbound_federation_restricted_to": ["federation_sender"],
+            "worker_replication_secret": "secret",
+        }
+    )
+    def test_proxy_requests_through_federation_sender_worker(self) -> None:
+        """
+        Test that all outbound federation requests go through the `federation_sender`
+        worker
+        """
+        # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+        # so we can act like some remote server responding to requests
+        mock_client_on_federation_sender = Mock()
+        mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+        mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+        # Create the `federation_sender` worker
+        self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "federation_sender"},
+            federation_http_client=mock_client_on_federation_sender,
+        )
+
+        # Fake `remoteserv:8008` responding to requests
+        mock_agent_on_federation_sender.request.side_effect = (
+            lambda *args, **kwargs: defer.succeed(
+                FakeResponse.json(
+                    payload={
+                        "foo": "bar",
+                    }
+                )
+            )
+        )
+
+        # This federation request from the main process should be proxied through the
+        # `federation_sender` worker off to the remote server
+        test_request_from_main_process_d = defer.ensureDeferred(
+            self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
+        )
+
+        # Pump the reactor so our deferred goes through the motions
+        self.pump()
+
+        # Make sure that the request was proxied through the `federation_sender` worker
+        mock_agent_on_federation_sender.request.assert_called_once_with(
+            b"GET",
+            b"matrix-federation://remoteserv:8008/foo/bar",
+            headers=ANY,
+            bodyProducer=ANY,
+        )
+
+        # Make sure the response is as expected back on the main worker
+        res = self.successResultOf(test_request_from_main_process_d)
+        self.assertEqual(res, {"foo": "bar"})
+
+    @override_config(
+        {
+            "outbound_federation_restricted_to": ["federation_sender"],
+            "worker_replication_secret": "secret",
+        }
+    )
+    def test_proxy_request_with_network_error_through_federation_sender_worker(
+        self,
+    ) -> None:
+        """
+        Test that when the outbound federation request fails with a network related
+        error, a sensible error makes its way back to the main process.
+        """
+        # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+        # so we can act like some remote server responding to requests
+        mock_client_on_federation_sender = Mock()
+        mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+        mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+        # Create the `federation_sender` worker
+        self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "federation_sender"},
+            federation_http_client=mock_client_on_federation_sender,
+        )
+
+        # Fake `remoteserv:8008` responding to requests
+        mock_agent_on_federation_sender.request.side_effect = (
+            lambda *args, **kwargs: defer.fail(ResponseNeverReceived("fake error"))
+        )
+
+        # This federation request from the main process should be proxied through the
+        # `federation_sender` worker off to the remote server
+        test_request_from_main_process_d = defer.ensureDeferred(
+            self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
+        )
+
+        # Pump the reactor so our deferred goes through the motions. We pump with 10
+        # seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries
+        # and finally passes along the error response.
+        self.pump(0.1)
+
+        # Make sure that the request was proxied through the `federation_sender` worker
+        mock_agent_on_federation_sender.request.assert_called_with(
+            b"GET",
+            b"matrix-federation://remoteserv:8008/foo/bar",
+            headers=ANY,
+            bodyProducer=ANY,
+        )
+
+        # Make sure we get some sort of error back on the main worker
+        failure_res = self.failureResultOf(test_request_from_main_process_d)
+        self.assertIsInstance(failure_res.value, RequestSendFailed)
+        self.assertIsInstance(failure_res.value.inner_exception, HttpResponseException)
+        self.assertEqual(failure_res.value.inner_exception.code, 502)
+
+    @override_config(
+        {
+            "outbound_federation_restricted_to": ["federation_sender"],
+            "worker_replication_secret": "secret",
+        }
+    )
+    def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None:
+        """
+        Test to make sure hop-by-hop headers and addional headers defined in the
+        `Connection` header are discarded when proxying requests
+        """
+        # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+        # so we can act like some remote server responding to requests
+        mock_client_on_federation_sender = Mock()
+        mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+        mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+        # Create the `federation_sender` worker
+        self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {"worker_name": "federation_sender"},
+            federation_http_client=mock_client_on_federation_sender,
+        )
+
+        # Fake `remoteserv:8008` responding to requests
+        mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed(
+            FakeResponse(
+                code=200,
+                body=b'{"foo": "bar"}',
+                headers=Headers(
+                    {
+                        "Content-Type": ["application/json"],
+                        "Connection": ["close, X-Foo, X-Bar"],
+                        # Should be removed because it's defined in the `Connection` header
+                        "X-Foo": ["foo"],
+                        "X-Bar": ["bar"],
+                        # Should be removed because it's a hop-by-hop header
+                        "Proxy-Authorization": "abcdef",
+                    }
+                ),
+            )
+        )
+
+        # This federation request from the main process should be proxied through the
+        # `federation_sender` worker off to the remote server
+        test_request_from_main_process_d = defer.ensureDeferred(
+            self.hs.get_federation_http_client().get_json_with_headers(
+                "remoteserv:8008", "foo/bar"
+            )
+        )
+
+        # Pump the reactor so our deferred goes through the motions
+        self.pump()
+
+        # Make sure that the request was proxied through the `federation_sender` worker
+        mock_agent_on_federation_sender.request.assert_called_once_with(
+            b"GET",
+            b"matrix-federation://remoteserv:8008/foo/bar",
+            headers=ANY,
+            bodyProducer=ANY,
+        )
+
+        res, headers = self.successResultOf(test_request_from_main_process_d)
+        header_names = set(headers.keys())
+
+        # Make sure the response does not include the hop-by-hop headers
+        self.assertNotIn(b"X-Foo", header_names)
+        self.assertNotIn(b"X-Bar", header_names)
+        self.assertNotIn(b"Proxy-Authorization", header_names)
+        # Make sure the response is as expected back on the main worker
+        self.assertEqual(res, {"foo": "bar"})
+
+    @override_config(
+        {
+            "outbound_federation_restricted_to": ["federation_sender"],
+            # `worker_replication_secret` is set here so that the test setup is able to pass
+            # but the actual homserver creation test is in the test body below
+            "worker_replication_secret": "secret",
+        }
+    )
+    def test_not_able_to_proxy_requests_through_federation_sender_worker_when_no_secret_configured(
+        self,
+    ) -> None:
+        """
+        Test that we aren't able to proxy any outbound federation requests when
+        `worker_replication_secret` is not configured.
+        """
+        with self.assertRaises(ConfigError):
+            # Create the `federation_sender` worker
+            self.make_worker_hs(
+                "synapse.app.generic_worker",
+                {
+                    "worker_name": "federation_sender",
+                    # Test that we aren't able to proxy any outbound federation requests
+                    # when `worker_replication_secret` is not configured.
+                    "worker_replication_secret": None,
+                },
+            )
+
+    @override_config(
+        {
+            "outbound_federation_restricted_to": ["federation_sender"],
+            "worker_replication_secret": "secret",
+        }
+    )
+    def test_not_able_to_proxy_requests_through_federation_sender_worker_when_wrong_auth_given(
+        self,
+    ) -> None:
+        """
+        Test that we aren't able to proxy any outbound federation requests when the
+        wrong authorization is given.
+        """
+        # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+        # so we can act like some remote server responding to requests
+        mock_client_on_federation_sender = Mock()
+        mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+        mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+        # Create the `federation_sender` worker
+        self.make_worker_hs(
+            "synapse.app.generic_worker",
+            {
+                "worker_name": "federation_sender",
+                # Test that we aren't able to proxy any outbound federation requests
+                # when `worker_replication_secret` is wrong.
+                "worker_replication_secret": "wrong",
+            },
+            federation_http_client=mock_client_on_federation_sender,
+        )
+
+        # This federation request from the main process should be proxied through the
+        # `federation_sender` worker off but will fail here because it's using the wrong
+        # authorization.
+        test_request_from_main_process_d = defer.ensureDeferred(
+            self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
+        )
+
+        # Pump the reactor so our deferred goes through the motions. We pump with 10
+        # seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries
+        # and finally passes along the error response.
+        self.pump(0.1)
+
+        # Make sure that the request was *NOT* proxied through the `federation_sender`
+        # worker
+        mock_agent_on_federation_sender.request.assert_not_called()
+
+        failure_res = self.failureResultOf(test_request_from_main_process_d)
+        self.assertIsInstance(failure_res.value, HttpResponseException)
+        self.assertEqual(failure_res.value.code, 401)
diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py
new file mode 100644
index 0000000000..0dc9ba8e05
--- /dev/null
+++ b/tests/http/test_proxy.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Set
+
+from parameterized import parameterized
+
+from synapse.http.proxy import parse_connection_header_value
+
+from tests.unittest import TestCase
+
+
+class ProxyTests(TestCase):
+    @parameterized.expand(
+        [
+            [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
+            # No whitespace
+            [b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}],
+            # More whitespace
+            [b"close,    X-Foo,      X-Bar", {"Close", "X-Foo", "X-Bar"}],
+            # "close" directive in not the first position
+            [b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}],
+            # Normalizes header capitalization
+            [b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}],
+            # Handles header names with whitespace
+            [
+                b"keep-alive, x  foo, x bar",
+                {"Keep-Alive", "X  foo", "X bar"},
+            ],
+        ]
+    )
+    def test_parse_connection_header_value(
+        self,
+        connection_header_value: bytes,
+        expected_extra_headers_to_remove: Set[str],
+    ) -> None:
+        """
+        Tests that the connection header value is parsed correctly
+        """
+        self.assertEqual(
+            expected_extra_headers_to_remove,
+            parse_connection_header_value(connection_header_value),
+        )
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index e0ae5a88ff..8164b0b78e 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -33,7 +33,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
 from synapse.http.client import BlocklistingReactorWrapper
-from synapse.http.connectproxyclient import ProxyCredentials
+from synapse.http.connectproxyclient import BasicProxyCredentials
 from synapse.http.proxyagent import ProxyAgent, parse_proxy
 
 from tests.http import (
@@ -205,7 +205,7 @@ class ProxyParserTests(TestCase):
         """
         proxy_cred = None
         if expected_credentials:
-            proxy_cred = ProxyCredentials(expected_credentials)
+            proxy_cred = BasicProxyCredentials(expected_credentials)
         self.assertEqual(
             (
                 expected_scheme,
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e28ba84cc2..1bc7d64ad9 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import cast
+from typing import Awaitable, cast
 
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactorClock
@@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
         Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
         with functions that return deferreds
         """
-        reactor = MemoryReactorClock()
-
         with LoggingContext("root context"):
 
             @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
 
             result_d1 = fixture_deferred_func()
 
-            # let the tasks complete
-            reactor.pump((2,) * 8)
-
             self.assertEqual(self.successResultOf(result_d1), "foo")
 
         # the span should have been reported
@@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
         Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
         with async functions
         """
-        reactor = MemoryReactorClock()
-
         with LoggingContext("root context"):
 
             @trace_with_opname("fixture_async_func", tracer=self._tracer)
@@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
 
             d1 = defer.ensureDeferred(fixture_async_func())
 
-            # let the tasks complete
-            reactor.pump((2,) * 8)
-
             self.assertEqual(self.successResultOf(d1), "foo")
 
         # the span should have been reported
@@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
             [span.operation_name for span in self._reporter.get_spans()],
             ["fixture_async_func"],
         )
+
+    def test_trace_decorator_awaitable_return(self) -> None:
+        """
+        Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
+        with functions that return an awaitable (e.g. a coroutine)
+        """
+        with LoggingContext("root context"):
+            # Something we can return without `await` to get a coroutine
+            async def fixture_async_func() -> str:
+                return "foo"
+
+            # The actual kind of function we want to test that returns an awaitable
+            @trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
+            @tag_args
+            def fixture_awaitable_return_func() -> Awaitable[str]:
+                return fixture_async_func()
+
+            # Something we can run with `defer.ensureDeferred(runner())` and pump the
+            # whole async tasks through to completion.
+            async def runner() -> str:
+                return await fixture_awaitable_return_func()
+
+            d1 = defer.ensureDeferred(runner())
+
+            self.assertEqual(self.successResultOf(d1), "foo")
+
+        # the span should have been reported
+        self.assertEqual(
+            [span.operation_name for span in self._reporter.get_spans()],
+            ["fixture_awaitable_return_func"],
+        )
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index 5191e31a8a..45eac100bf 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -78,11 +78,11 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         logger = self.get_logger(handler)
 
         # Send some debug messages
-        for i in range(0, 3):
+        for i in range(3):
             logger.debug("debug %s" % (i,))
 
         # Send a bunch of useful messages
-        for i in range(0, 7):
+        for i in range(7):
             logger.info("info %s" % (i,))
 
         # The last debug message pushes it past the maximum buffer
@@ -108,15 +108,15 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         logger = self.get_logger(handler)
 
         # Send some debug messages
-        for i in range(0, 3):
+        for i in range(3):
             logger.debug("debug %s" % (i,))
 
         # Send a bunch of useful messages
-        for i in range(0, 10):
+        for i in range(10):
             logger.warning("warn %s" % (i,))
 
         # Send a bunch of info messages
-        for i in range(0, 3):
+        for i in range(3):
             logger.info("info %s" % (i,))
 
         # The last debug message pushes it past the maximum buffer
@@ -144,7 +144,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         logger = self.get_logger(handler)
 
         # Send a bunch of useful messages
-        for i in range(0, 20):
+        for i in range(20):
             logger.warning("warn %s" % (i,))
 
         # Allow the reconnection
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fa27f1279a..c379853e20 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -164,7 +164,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
         # Partially skip some internal processing of SynapseRequest.
-        request._started_processing = Mock()  # type: ignore[assignment]
+        request._started_processing = Mock()  # type: ignore[method-assign]
         request.request_metrics = Mock(spec=["name"])
         with patch.object(Request, "render"):
             request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
diff --git a/tests/media/test_base.py b/tests/media/test_base.py
index 66498c744d..119d7ba66f 100644
--- a/tests/media/test_base.py
+++ b/tests/media/test_base.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.media._base import get_filename_from_headers
+from unittest.mock import Mock
+
+from synapse.media._base import add_file_headers, get_filename_from_headers
 
 from tests import unittest
 
@@ -20,12 +22,12 @@ from tests import unittest
 class GetFileNameFromHeadersTests(unittest.TestCase):
     # input -> expected result
     TEST_CASES = {
-        b"inline; filename=abc.txt": "abc.txt",
-        b'inline; filename="azerty"': "azerty",
-        b'inline; filename="aze%20rty"': "aze%20rty",
-        b'inline; filename="aze"rty"': 'aze"rty',
-        b'inline; filename="azer;ty"': "azer;ty",
-        b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
+        b"attachment; filename=abc.txt": "abc.txt",
+        b'attachment; filename="azerty"': "azerty",
+        b'attachment; filename="aze%20rty"': "aze%20rty",
+        b'attachment; filename="aze"rty"': 'aze"rty',
+        b'attachment; filename="azer;ty"': "azer;ty",
+        b"attachment; filename*=utf-8''foo%C2%A3bar": "foo£bar",
     }
 
     def tests(self) -> None:
@@ -36,3 +38,28 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
                 expected,
                 f"expected output for {hdr!r} to be {expected} but was {res}",
             )
+
+
+class AddFileHeadersTests(unittest.TestCase):
+    TEST_CASES = {
+        "text/plain": b"inline; filename=file.name",
+        "text/csv": b"inline; filename=file.name",
+        "image/png": b"inline; filename=file.name",
+        "text/html": b"attachment; filename=file.name",
+        "any/thing": b"attachment; filename=file.name",
+    }
+
+    def test_content_disposition(self) -> None:
+        for media_type, expected in self.TEST_CASES.items():
+            request = Mock()
+            add_file_headers(request, media_type, 0, "file.name")
+            request.setHeader.assert_any_call(b"Content-Disposition", expected)
+
+    def test_no_filename(self) -> None:
+        request = Mock()
+        add_file_headers(request, "text/plain", 0, None)
+        request.setHeader.assert_any_call(b"Content-Disposition", b"inline")
+
+        request.reset_mock()
+        add_file_headers(request, "text/html", 0, None)
+        request.setHeader.assert_any_call(b"Content-Disposition", b"attachment")
diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py
index e7da75db3e..ea84bb3d3d 100644
--- a/tests/media/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -24,7 +24,7 @@ from tests import unittest
 try:
     import lxml
 except ImportError:
-    lxml = None
+    lxml = None  # type: ignore[assignment]
 
 
 class SummarizeTestCase(unittest.TestCase):
@@ -160,6 +160,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -176,6 +177,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -195,6 +197,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(
@@ -217,6 +220,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -231,6 +235,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -246,6 +251,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
@@ -261,6 +267,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -281,6 +288,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
@@ -296,6 +304,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -324,6 +333,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -338,6 +348,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -353,6 +364,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
@@ -367,6 +379,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
@@ -380,6 +393,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(
             og,
@@ -401,6 +415,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(
             og,
@@ -419,6 +434,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
         with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
         """
         tree = decode_body(html, "http://example.com/test.html")
+        assert tree is not None
         og = parse_html_to_open_graph(tree)
         self.assertEqual(
             og,
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index f0f2da65db..15f5d644e4 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -28,12 +28,13 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
-from synapse.media._base import FileInfo
+from synapse.media._base import FileInfo, ThumbnailInfo
 from synapse.media.filepath import MediaFilePaths
 from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.media.storage_provider import FileStorageProviderBackend
@@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi
 from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
 from synapse.rest import admin
 from synapse.rest.client import login
+from synapse.rest.media.thumbnail_resource import ThumbnailResource
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomAlias
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeChannel, FakeSite, make_request
+from tests.server import FakeChannel
 from tests.test_utils import SMALL_PNG
 from tests.utils import default_config
 
@@ -129,6 +131,8 @@ class _TestImage:
             a 404/400 is expected.
         unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
             False if the thumbnailing should succeed or a normal 404 is expected.
+        is_inline: True if we expect the file to be served using an inline
+            Content-Disposition or False if we expect an attachment.
     """
 
     data: bytes
@@ -138,6 +142,7 @@ class _TestImage:
     expected_scaled: Optional[bytes] = None
     expected_found: bool = True
     unable_to_thumbnail: bool = False
+    is_inline: bool = True
 
 
 @parameterized_class(
@@ -198,6 +203,25 @@ class _TestImage:
                 unable_to_thumbnail=True,
             ),
         ),
+        # An SVG.
+        (
+            _TestImage(
+                b"""<?xml version="1.0"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+
+<svg xmlns="http://www.w3.org/2000/svg"
+     width="400" height="400">
+  <circle cx="100" cy="100" r="50" stroke="black"
+    stroke-width="5" fill="red" />
+</svg>""",
+                b"image/svg",
+                b".svg",
+                expected_found=False,
+                unable_to_thumbnail=True,
+                is_inline=False,
+            ),
+        ),
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
@@ -266,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        media_resource = hs.get_media_repository_resource()
-        self.download_resource = media_resource.children[b"download"]
-        self.thumbnail_resource = media_resource.children[b"thumbnail"]
         self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository()
 
         self.media_id = "example.com/12345"
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     def _req(
         self, content_disposition: Optional[bytes], include_content_type: bool = True
     ) -> FakeChannel:
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id,
+            f"/_matrix/media/v3/download/{self.media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -317,7 +341,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def test_handle_missing_content_type(self) -> None:
         channel = self._req(
-            b"inline; filename=out" + self.test_image.extension,
+            b"attachment; filename=out" + self.test_image.extension,
             include_content_type=False,
         )
         headers = channel.headers
@@ -331,7 +355,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         If the filename is filename=<ascii> then Synapse will decode it as an
         ASCII string, and use filename= in the response.
         """
-        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
 
         headers = channel.headers
         self.assertEqual(
@@ -339,7 +363,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(
             headers.getRawHeaders(b"Content-Disposition"),
-            [b"inline; filename=out" + self.test_image.extension],
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename=out"
+                + self.test_image.extension
+            ],
         )
 
     def test_disposition_filenamestar_utf8escaped(self) -> None:
@@ -350,7 +378,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """
         filename = parse.quote("\u2603".encode()).encode("ascii")
         channel = self._req(
-            b"inline; filename*=utf-8''" + filename + self.test_image.extension
+            b"attachment; filename*=utf-8''" + filename + self.test_image.extension
         )
 
         headers = channel.headers
@@ -359,13 +387,18 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(
             headers.getRawHeaders(b"Content-Disposition"),
-            [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename*=utf-8''"
+                + filename
+                + self.test_image.extension
+            ],
         )
 
     def test_disposition_none(self) -> None:
         """
-        If there is no filename, one isn't passed on in the Content-Disposition
-        of the request.
+        If there is no filename, Content-Disposition should only
+        be a disposition type.
         """
         channel = self._req(None)
 
@@ -373,7 +406,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(
             headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
         )
-        self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline" if self.test_image.is_inline else b"attachment"],
+        )
 
     def test_thumbnail_crop(self) -> None:
         """Test that a cropped remote thumbnail is available."""
@@ -447,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         # Fetching again should work, without re-requesting the image from the
         # remote.
         params = "?width=32&height=32&method=scale"
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -477,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         shutil.rmtree(thumbnail_dir, ignore_errors=True)
 
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -515,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """
 
         params = "?width=32&height=32&method=" + method
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -556,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.json_body,
                 {
                     "errcode": "M_UNKNOWN",
-                    "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+                    "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
                 },
             )
         else:
@@ -566,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.json_body,
                 {
                     "errcode": "M_NOT_FOUND",
-                    "error": "Not found [b'example.com', b'12345']",
+                    "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
                 },
             )
 
@@ -575,34 +605,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """Test that choosing between thumbnails with the same quality rating succeeds.
 
         We are not particular about which thumbnail is chosen."""
+
+        content_type = self.test_image.content_type.decode()
+        media_repo = self.hs.get_media_repository()
+        thumbnail_resouce = ThumbnailResource(
+            self.hs, media_repo, media_repo.media_storage
+        )
+
         self.assertIsNotNone(
-            self.thumbnail_resource._select_thumbnail(
+            thumbnail_resouce._select_thumbnail(
                 desired_width=desired_size,
                 desired_height=desired_size,
                 desired_method=method,
-                desired_type=self.test_image.content_type,
+                desired_type=content_type,
                 # Provide two identical thumbnails which are guaranteed to have the same
                 # quality rating.
                 thumbnail_infos=[
-                    {
-                        "thumbnail_width": 32,
-                        "thumbnail_height": 32,
-                        "thumbnail_method": method,
-                        "thumbnail_type": self.test_image.content_type,
-                        "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
-                    },
-                    {
-                        "thumbnail_width": 32,
-                        "thumbnail_height": 32,
-                        "thumbnail_method": method,
-                        "thumbnail_type": self.test_image.content_type,
-                        "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
-                    },
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
                 ],
                 file_id=f"image{self.test_image.extension.decode()}",
-                url_cache=None,
+                url_cache=False,
                 server_name=None,
             )
         )
@@ -612,7 +647,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         Tests that the `X-Robots-Tag` header is present, which informs web crawlers
         to not index, archive, or follow links in media.
         """
-        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
 
         headers = channel.headers
         self.assertEqual(
@@ -625,7 +660,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
         allowing web clients to embed media from the downloads API.
         """
-        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
 
         headers = channel.headers
 
@@ -691,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
-        # Allow for uploading and downloading to/from the media repo
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.upload_resource = self.media_repo.children[b"upload"]
-
         load_legacy_spam_checkers(hs)
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
@@ -717,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
 
     def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
-        self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
-        )
+        self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
 
     def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
@@ -728,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
 
         data = b"Some evil data"
 
-        self.helper.upload_media(
-            self.upload_resource, data, tok=self.tok, expect_code=400
-        )
+        self.helper.upload_media(data, tok=self.tok, expect_code=400)
 
 
 EVIL_DATA = b"Some evil data"
@@ -747,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
-        # Allow for uploading and downloading to/from the media repo
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.upload_resource = self.media_repo.children[b"upload"]
-
         hs.get_module_api().register_spam_checker_callbacks(
             check_media_file_for_spam=self.check_media_file_for_spam
         )
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
     ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
@@ -771,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
-        self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
-        )
+        self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
 
     def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
         get rejected by the spam checker.
         """
 
-        self.helper.upload_media(
-            self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
-        )
+        self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400)
 
         self.helper.upload_media(
-            self.upload_resource,
             EVIL_DATA_EXPERIMENT,
             tok=self.tok,
             expect_code=400,
diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index c8bf8421da..3bc19cb1cc 100644
--- a/tests/media/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -28,7 +28,7 @@ from tests.unittest import HomeserverTestCase
 try:
     import lxml
 except ImportError:
-    lxml = None
+    lxml = None  # type: ignore[assignment]
 
 
 class OEmbedTests(HomeserverTestCase):
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 3c4c7d6765..04b69f378a 100644
--- a/tests/media/test_url_previewer.py
+++ b/tests/media/test_url_previewer.py
@@ -24,7 +24,7 @@ from tests.unittest import override_config
 try:
     import lxml
 except ImportError:
-    lxml = None
+    lxml = None  # type: ignore[assignment]
 
 
 class URLPreviewTests(unittest.HomeserverTestCase):
@@ -61,9 +61,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         return self.setup_test_homeserver(config=config)
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        media_repo_resource = hs.get_media_repository_resource()
-        preview_url = media_repo_resource.children[b"preview_url"]
-        self.url_previewer = preview_url._url_previewer
+        media_repo = hs.get_media_repository()
+        assert media_repo.url_previewer is not None
+        self.url_previewer = media_repo.url_previewer
 
     def test_all_urls_allowed(self) -> None:
         self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org"))
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 7c3656d049..d14876826c 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -12,19 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from importlib import metadata
 from typing import Dict, Tuple
-
-from typing_extensions import Protocol
-
-try:
-    from importlib import metadata
-except ImportError:
-    import importlib_metadata as metadata  # type: ignore[no-redef]
-
 from unittest.mock import patch
 
 from pkg_resources import parse_version
 from prometheus_client.core import Sample
+from typing_extensions import Protocol
 
 from synapse.app._base import _set_prometheus_client_use_created_metrics
 from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index bff7114cd8..172fc3a736 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
@@ -28,12 +28,11 @@ from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, notifications, presence, profile, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.test_utils import simple_async_mock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,
@@ -103,7 +102,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         self.assertEqual(email["added_at"], 0)
 
         # Check that the displayname was assigned
-        displayname = self.get_success(self.store.get_profile_displayname("bob"))
+        displayname = self.get_success(
+            self.store.get_profile_displayname(UserID.from_string("@bob:test"))
+        )
         self.assertEqual(displayname, "Bobberino")
 
     def test_can_register_admin_user(self) -> None:
@@ -232,7 +233,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
     def test_sending_events_into_room(self) -> None:
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
-        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[assignment]
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[method-assign]
             spec=[],
             side_effect=self.event_creation_handler.create_and_send_nonmember_event,
         )
@@ -577,10 +578,8 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
-        mocked_remote_join = simple_async_mock(
-            return_value=("fake-event-id", fake_stream_id)
-        )
-        self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[assignment]
+        mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
+        self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[method-assign]
         fake_remote_host = f"{self.module_api.server_name}-remote"
 
         # Given that the join is to be faked, we expect the relevant join event not to
@@ -755,7 +754,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         self.assertEqual(channel.json_body["creator"], user_id)
 
         # Check room alias.
-        self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}")
+        self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
 
         # Let's try a room with no alias.
         room_id, room_alias = self.get_success(
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 9501096a77..7c23b77e0a 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Any, Optional
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from parameterized import parameterized
 
@@ -28,7 +28,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 
@@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         # Mock the method which calculates push rules -- we do this instead of
         # e.g. checking the results in the database because we want to ensure
         # that code isn't even running.
-        bulk_evaluator._action_for_event_by_user = simple_async_mock()  # type: ignore[assignment]
+        bulk_evaluator._action_for_event_by_user = AsyncMock()  # type: ignore[method-assign]
 
         # Ensure no actions are generated!
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@@ -228,7 +227,6 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         )
         return len(result) > 0
 
-    @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
     def test_user_mentions(self) -> None:
         """Test the behavior of an event which includes invalid user mentions."""
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -237,9 +235,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         self.assertFalse(self._create_and_process(bulk_evaluator))
         # An empty mentions field should not notify.
         self.assertFalse(
-            self._create_and_process(
-                bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}}
-            )
+            self._create_and_process(bulk_evaluator, {EventContentFields.MENTIONS: {}})
         )
 
         # Non-dict mentions should be ignored.
@@ -253,7 +249,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
             for mentions in (None, True, False, 1, "foo", []):
                 self.assertFalse(
                     self._create_and_process(
-                        bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions}
+                        bulk_evaluator, {EventContentFields.MENTIONS: mentions}
                     )
                 )
 
@@ -262,7 +258,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 self.assertFalse(
                     self._create_and_process(
                         bulk_evaluator,
-                        {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}},
+                        {EventContentFields.MENTIONS: {"user_ids": mentions}},
                     )
                 )
 
@@ -270,14 +266,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         self.assertTrue(
             self._create_and_process(
                 bulk_evaluator,
-                {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}},
+                {EventContentFields.MENTIONS: {"user_ids": [self.alice]}},
             )
         )
         self.assertTrue(
             self._create_and_process(
                 bulk_evaluator,
                 {
-                    EventContentFields.MSC3952_MENTIONS: {
+                    EventContentFields.MENTIONS: {
                         "user_ids": ["@another:test", self.alice]
                     }
                 },
@@ -288,11 +284,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         self.assertTrue(
             self._create_and_process(
                 bulk_evaluator,
-                {
-                    EventContentFields.MSC3952_MENTIONS: {
-                        "user_ids": [self.alice, self.alice]
-                    }
-                },
+                {EventContentFields.MENTIONS: {"user_ids": [self.alice, self.alice]}},
             )
         )
 
@@ -307,7 +299,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 self._create_and_process(
                     bulk_evaluator,
                     {
-                        EventContentFields.MSC3952_MENTIONS: {
+                        EventContentFields.MENTIONS: {
                             "user_ids": [None, True, False, {}, []]
                         }
                     },
@@ -317,7 +309,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 self._create_and_process(
                     bulk_evaluator,
                     {
-                        EventContentFields.MSC3952_MENTIONS: {
+                        EventContentFields.MENTIONS: {
                             "user_ids": [None, True, False, {}, [], self.alice]
                         }
                     },
@@ -331,12 +323,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 {
                     "body": self.alice,
                     "msgtype": "m.text",
-                    EventContentFields.MSC3952_MENTIONS: {},
+                    EventContentFields.MENTIONS: {},
                 },
             )
         )
 
-    @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
     def test_room_mentions(self) -> None:
         """Test the behavior of an event which includes invalid room mentions."""
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -344,7 +335,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         # Room mentions from those without power should not notify.
         self.assertFalse(
             self._create_and_process(
-                bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+                bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}}
             )
         )
 
@@ -358,7 +349,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         )
         self.assertTrue(
             self._create_and_process(
-                bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+                bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}}
             )
         )
 
@@ -374,7 +365,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 self.assertFalse(
                     self._create_and_process(
                         bulk_evaluator,
-                        {EventContentFields.MSC3952_MENTIONS: {"room": mentions}},
+                        {EventContentFields.MENTIONS: {"room": mentions}},
                     )
                 )
 
@@ -385,12 +376,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 {
                     "body": "@room",
                     "msgtype": "m.text",
-                    EventContentFields.MSC3952_MENTIONS: {},
+                    EventContentFields.MENTIONS: {},
                 },
             )
         )
 
-    @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}})
     def test_suppress_edits(self) -> None:
         """Under the default push rules, event edits should not generate notifications."""
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -417,16 +407,33 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
             )
         )
 
-        # Room mentions from those without power should not notify.
+        # The edit should not cause a notification.
         self.assertFalse(
             self._create_and_process(
                 bulk_evaluator,
                 {
-                    "body": self.alice,
+                    "body": "Test message",
+                    "m.relates_to": {
+                        "rel_type": RelationTypes.REPLACE,
+                        "event_id": event.event_id,
+                    },
+                },
+            )
+        )
+
+        # An edit which is a mention will cause a notification.
+        self.assertTrue(
+            self._create_and_process(
+                bulk_evaluator,
+                {
+                    "body": "Test message",
                     "m.relates_to": {
                         "rel_type": RelationTypes.REPLACE,
                         "event_id": event.event_id,
                     },
+                    "m.mentions": {
+                        "user_ids": [self.alice],
+                    },
                 },
             )
         )
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 4b5c96aeae..73a430ddc6 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 import email.message
 import os
+from http import HTTPStatus
 from typing import Any, Dict, List, Sequence, Tuple
 
 import attr
 import pkg_resources
+from parameterized import parameterized
 
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
@@ -25,9 +27,11 @@ import synapse.rest.admin
 from synapse.api.errors import Codes, SynapseError
 from synapse.push.emailpusher import EmailPusher
 from synapse.rest.client import login, room
+from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
 from synapse.server import HomeServer
 from synapse.util import Clock
 
+from tests.server import FakeSite, make_request
 from tests.unittest import HomeserverTestCase
 
 
@@ -175,6 +179,57 @@ class EmailPusherTests(HomeserverTestCase):
 
         self._check_for_mail()
 
+    @parameterized.expand([(False,), (True,)])
+    def test_unsubscribe(self, use_post: bool) -> None:
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends a single message.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+        # We should get emailed about that message
+        args, kwargs = self._check_for_mail()
+
+        # That email should contain an unsubscribe link in the body and header.
+        msg: bytes = args[5]
+
+        # Multipart: plain text, base 64 encoded; html, base 64 encoded
+        multipart_msg = email.message_from_bytes(msg)
+        txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode()
+        html = multipart_msg.get_payload()[1].get_payload(decode=True).decode()
+        self.assertIn("/_synapse/client/unsubscribe", txt)
+        self.assertIn("/_synapse/client/unsubscribe", html)
+
+        # The unsubscribe headers should exist.
+        assert multipart_msg.get("List-Unsubscribe") is not None
+        self.assertIsNotNone(multipart_msg.get("List-Unsubscribe-Post"))
+
+        # Open the unsubscribe link.
+        unsubscribe_link = multipart_msg["List-Unsubscribe"].strip("<>")
+        unsubscribe_resource = UnsubscribeResource(self.hs)
+        channel = make_request(
+            self.reactor,
+            FakeSite(unsubscribe_resource, self.reactor),
+            "POST" if use_post else "GET",
+            unsubscribe_link,
+            shorthand=False,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+        # Ensure the pusher was removed.
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
+        )
+        self.assertEqual(pushers, [])
+
     def test_invite_sends_email(self) -> None:
         # Create a room and invite the user to it
         room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eb9b1f1cd9..6712ac485d 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
+from synapse.config.workers import InstanceTcpLocationConfig, InstanceUnixLocationConfig
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -69,10 +70,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.worker_hs = self.setup_test_homeserver(
-            federation_http_client=None,
             homeserver_to_use=GenericWorkerServer,
             config=self._get_worker_hs_config(),
             reactor=self.reactor,
+            federation_http_client=None,
         )
 
         # Since we use sqlite in memory databases we need to make sure the
@@ -339,7 +340,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # `_handle_http_replication_attempt` like we do with the master HS.
         instance_name = worker_hs.get_instance_name()
         instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
-        if instance_loc:
+        if instance_loc and isinstance(instance_loc, InstanceTcpLocationConfig):
             # Ensure the host is one that has a fake DNS entry.
             if instance_loc.host not in self.reactor.lookups:
                 raise Exception(
@@ -360,6 +361,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
                 instance_loc.port,
                 lambda: self._handle_http_replication_attempt(worker_hs, port),
             )
+        elif instance_loc and isinstance(instance_loc, InstanceUnixLocationConfig):
+            raise Exception(
+                "Unix sockets are not supported for unit tests at this time."
+            )
 
         store = worker_hs.get_datastores().main
         store.db_pool._db_pool = self.database_pool._db_pool
@@ -380,6 +385,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             server_version_string="1",
             max_request_body_size=8192,
             reactor=self.reactor,
+            hs=worker_hs,
         )
 
         worker_hs.get_replication_command_handler().start_replication(worker_hs)
diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py
index de26a62ae1..afcc80a8b3 100644
--- a/tests/replication/storage/_base.py
+++ b/tests/replication/storage/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional
 from unittest.mock import Mock
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -47,24 +47,31 @@ class BaseWorkerStoreTestCase(BaseStreamTestCase):
         self.pump(0.1)
 
     def check(
-        self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+        self,
+        method: str,
+        args: Iterable[Any],
+        expected_result: Optional[Any] = None,
+        asserter: Optional[Callable[[Any, Any, Optional[Any]], None]] = None,
     ) -> None:
+        if asserter is None:
+            asserter = self.assertEqual
+
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         worker_result = self.get_success(getattr(self.worker_store, method)(*args))
         if expected_result is not None:
-            self.assertEqual(
+            asserter(
                 master_result,
                 expected_result,
                 "Expected master result to be %r but was %r"
                 % (expected_result, master_result),
             )
-            self.assertEqual(
+            asserter(
                 worker_result,
                 expected_result,
                 "Expected worker result to be %r but was %r"
                 % (expected_result, worker_result),
             )
-        self.assertEqual(
+        asserter(
             master_result,
             worker_result,
             "Worker result %r does not match master result %r"
diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index f7c6417a09..17716253f8 100644
--- a/tests/replication/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 from parameterized import parameterized
@@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import ReceiptTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
 from synapse.server import HomeServer
@@ -46,32 +46,9 @@ ROOM_ID = "!room:test"
 logger = logging.getLogger(__name__)
 
 
-def dict_equals(self: EventBase, other: EventBase) -> bool:
-    me = encode_canonical_json(self.get_pdu_json())
-    them = encode_canonical_json(other.get_pdu_json())
-    return me == them
-
-
-def patch__eq__(cls: object) -> Callable[[], None]:
-    eq = getattr(cls, "__eq__", None)
-    cls.__eq__ = dict_equals  # type: ignore[assignment]
-
-    def unpatch() -> None:
-        if eq is not None:
-            cls.__eq__ = eq  # type: ignore[assignment]
-
-    return unpatch
-
-
 class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
     STORE_TYPE = EventsWorkerStore
 
-    def setUp(self) -> None:
-        # Patch up the equality operator for events so that we can check
-        # whether lists of events match using assertEqual
-        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
-        super().setUp()
-
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
@@ -84,13 +61,19 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
             )
         )
 
-    def tearDown(self) -> None:
-        [unpatch() for unpatch in self.unpatches]
+    def assertEventsEqual(
+        self, first: EventBase, second: EventBase, msg: Optional[Any] = None
+    ) -> None:
+        self.assertEqual(
+            encode_canonical_json(first.get_pdu_json()),
+            encode_canonical_json(second.get_pdu_json()),
+            msg,
+        )
 
     def test_get_latest_event_ids_in_room(self) -> None:
         create = self.persist(type="m.room.create", key="", creator=USER_ID)
         self.replicate()
-        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
 
         join = self.persist(
             type="m.room.member",
@@ -99,7 +82,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
             prev_events=[(create.event_id, {})],
         )
         self.replicate()
-        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})
 
     def test_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -107,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
 
         msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
         self.replicate()
-        self.check("get_event", [msg.event_id], msg)
+        self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)
 
         redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
         self.replicate()
@@ -119,7 +102,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
         redacted = make_event_from_dict(
             msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
         )
-        self.check("get_event", [msg.event_id], redacted)
+        self.check(
+            "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
+        )
 
     def test_backfilled_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -127,7 +112,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
 
         msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
         self.replicate()
-        self.check("get_event", [msg.event_id], msg)
+        self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)
 
         redaction = self.persist(
             type="m.room.redaction", redacts=msg.event_id, backfill=True
@@ -141,7 +126,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
         redacted = make_event_from_dict(
             msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
         )
-        self.check("get_event", [msg.event_id], redacted)
+        self.check(
+            "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
+        )
 
     def test_invites(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 65ef4bb160..128fc3e046 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: Sequence[str] = self.get_success(
+        fork_point = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: Sequence[str] = self.get_success(
+        fork_point = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.test_handler.received_rdata_rows.clear()
 
         # now roll back all that state by de-modding the users
-        prev_events = fork_point
+        prev_events = list(fork_point)
         pl_events = []
         for u in user_ids:
             pls["users"][u] = 0
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=list(prev_events),
+                    prev_event_ids=prev_events,
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
index fb9eac668f..ab379e8cf1 100644
--- a/tests/replication/tcp/streams/test_to_device.py
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -49,7 +49,7 @@ class ToDeviceStreamTestCase(BaseStreamTestCase):
 
         # add messages to the device inbox for user1 up until the
         # limit defined for a stream update batch
-        for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
+        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT):
             msg["content"] = {"device": {}}
             messages = {user1: {"device": msg}}
 
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 12668b34c5..cf59b1a204 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
         config["worker_app"] = "synapse.app.generic_worker"
         config["worker_name"] = "federation_sender1"
         config["federation_sender_instances"] = ["federation_sender1"]
+        config["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}}
         return config
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 08703206a9..59f4fdc70b 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -12,17 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
+
+from netaddr import IPSet
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
 from synapse.handlers.typing import TypingWriterHandler
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client import login, room
 from synapse.types import UserID, create_requester
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.test_utils import make_awaitable
+from tests.server import get_clock
 
 logger = logging.getLogger(__name__)
 
@@ -41,13 +44,25 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         room.register_servlets,
     ]
 
+    def setUp(self) -> None:
+        super().setUp()
+
+        reactor, _ = get_clock()
+        self.matrix_federation_agent = MatrixFederationAgent(
+            reactor,
+            tls_client_options_factory=None,
+            user_agent=b"SynapseInTrialTest/0.0.0",
+            ip_allowlist=None,
+            ip_blocklist=IPSet(),
+        )
+
     def test_send_event_single_sender(self) -> None:
         """Test that using a single federation sender worker correctly sends a
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.return_value = make_awaitable({})
-
+        mock_client.put_json = AsyncMock(return_value={})
+        mock_client.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
             {
@@ -77,7 +92,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.return_value = make_awaitable({})
+        mock_client1.put_json = AsyncMock(return_value={})
+        mock_client1.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
             {
@@ -91,7 +107,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.return_value = make_awaitable({})
+        mock_client2.put_json = AsyncMock(return_value={})
+        mock_client2.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
             {
@@ -144,7 +161,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.return_value = make_awaitable({})
+        mock_client1.put_json = AsyncMock(return_value={})
+        mock_client1.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
             {
@@ -158,7 +176,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.return_value = make_awaitable({})
+        mock_client2.put_json = AsyncMock(return_value={})
+        mock_client2.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
             {
@@ -242,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
 
         builder = factory.for_room_version(room_version, event_dict)
         join_event = self.get_success(
-            builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+            builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
         )
 
         self.get_success(federation.on_send_membership_event(remote_server, join_event))
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 1527b4a82d..b230a6c361 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import os
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
@@ -29,7 +29,7 @@ from synapse.util import Clock
 
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.server import FakeChannel, FakeTransport, make_request
 from tests.test_utils import SMALL_PNG
 
 logger = logging.getLogger(__name__)
@@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
         return conf
 
+    def make_worker_hs(
+        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
+    ) -> HomeServer:
+        worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
+        # Force the media paths onto the replication resource.
+        worker_hs.get_media_repository_resource().register_servlets(
+            self._hs_to_site[worker_hs].resource, worker_hs
+        )
+        return worker_hs
+
     def _get_media_req(
         self, hs: HomeServer, target: str, media_id: str
     ) -> Tuple[FakeChannel, Request]:
@@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             The channel for the *client* request and the *outbound* request for
             the media which the caller should respond to.
         """
-        resource = hs.get_media_repository_resource().children[b"download"]
         channel = make_request(
             self.reactor,
-            FakeSite(resource, self.reactor),
+            self._hs_to_site[hs],
             "GET",
-            f"/{target}/{media_id}",
+            f"/_matrix/media/r0/download/{target}/{media_id}",
             shorthand=False,
             access_token=self.access_token,
             await_result=False,
@@ -116,7 +125,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(request.method, b"GET")
         self.assertEqual(
             request.path,
-            f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
+            f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
         )
         self.assertEqual(
             request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 695e84357a..8646b2f0fd 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,10 +13,12 @@
 # limitations under the License.
 
 import urllib.parse
+from typing import Dict
 
 from parameterized import parameterized
 
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.http.server import JsonResource
@@ -26,7 +28,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeSite, make_request
 from tests.test_utils import SMALL_PNG
 
 
@@ -42,9 +43,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", self.url, shorthand=False)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
-        self.assertEqual(
-            {"server_version", "python_version"}, set(channel.json_body.keys())
-        )
+        self.assertEqual({"server_version"}, set(channel.json_body.keys()))
 
 
 class QuarantineMediaTestCase(unittest.HomeserverTestCase):
@@ -57,21 +56,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        # Allow for uploading and downloading to/from the media repo
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.upload_resource = self.media_repo.children[b"upload"]
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
 
     def _ensure_quarantined(
         self, admin_user_tok: str, server_and_media_id: str
     ) -> None:
         """Ensure a piece of media is quarantined when trying to access it."""
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id,
+            f"/_matrix/media/v3/download/{server_and_media_id}",
             shorthand=False,
             access_token=admin_user_tok,
         )
@@ -119,20 +115,16 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         non_admin_user_tok = self.login("id_nonadmin", "pass")
 
         # Upload some media into the room
-        response = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=admin_user_tok
-        )
+        response = self.helper.upload_media(SMALL_PNG, tok=admin_user_tok)
 
         # Extract media ID from the response
         server_name_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
         server_name, media_id = server_name_and_media_id.split("/")
 
         # Attempt to access the media
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_name_and_media_id,
+            f"/_matrix/media/v3/download/{server_name_and_media_id}",
             shorthand=False,
             access_token=non_admin_user_tok,
         )
@@ -175,12 +167,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
 
         # Upload some media
-        response_1 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
-        response_2 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
+        response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+        response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
 
         # Extract mxcs
         mxc_1 = response_1["content_uri"]
@@ -229,12 +217,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         non_admin_user_tok = self.login("user_nonadmin", "pass")
 
         # Upload some media
-        response_1 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
-        response_2 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
+        response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+        response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
 
         # Extract media IDs
         server_and_media_id_1 = response_1["content_uri"][6:]
@@ -267,12 +251,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         non_admin_user_tok = self.login("user_nonadmin", "pass")
 
         # Upload some media
-        response_1 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
-        response_2 = self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
-        )
+        response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+        response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
 
         # Extract media IDs
         server_and_media_id_1 = response_1["content_uri"][6:]
@@ -306,11 +286,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
 
         # Attempt to access each piece of media
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id_2,
+            f"/_matrix/media/v3/download/{server_and_media_id_2}",
             shorthand=False,
             access_token=non_admin_user_tok,
         )
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 4c7864c629..0e2824d1b5 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -510,7 +510,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         Args:
             number_destinations: Number of destinations to be created
         """
-        for i in range(0, number_destinations):
+        for i in range(number_destinations):
             dest = f"sub{i}.example.com"
             self._create_destination(dest, 50, 50, 50, 100)
 
@@ -690,7 +690,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
         self._check_fields(channel_desc.json_body["rooms"])
 
         # test that both lists have different directions
-        for i in range(0, number_rooms):
+        for i in range(number_rooms):
             self.assertEqual(
                 channel_asc.json_body["rooms"][i]["room_id"],
                 channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"],
@@ -777,7 +777,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
         Args:
             number_rooms: Number of rooms to be created
         """
-        for _ in range(0, number_rooms):
+        for _ in range(number_rooms):
             room_id = self.helper.create_room_as(
                 self.admin_user, tok=self.admin_user_tok
             )
diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
new file mode 100644
index 0000000000..a9a6191c73
--- /dev/null
+++ b/tests/rest/admin/test_jwks.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+    import authlib  # noqa: F401
+
+    HAS_AUTHLIB = True
+except ImportError:
+    HAS_AUTHLIB = False
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class JWKSTestCase(HomeserverTestCase):
+    """Test /_synapse/jwks JWKS data."""
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d.update(build_synapse_client_resource_tree(self.hs))
+        return d
+
+    def test_empty_jwks(self) -> None:
+        """Test that the JWKS endpoint is not present by default."""
+        channel = self.make_request("GET", "/_synapse/jwks")
+        self.assertEqual(404, channel.code, channel.result)
+
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer/",
+                    "client_id": "test-client-id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "secret",
+                },
+            },
+        }
+    )
+    def test_empty_jwks_for_msc3861_client_secret_post(self) -> None:
+        """Test that the JWKS endpoint is empty when plain auth is used."""
+        channel = self.make_request("GET", "/_synapse/jwks")
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual({"keys": []}, channel.json_body)
+
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer/",
+                    "client_id": "test-client-id",
+                    "client_auth_method": "private_key_jwt",
+                    "jwk": {
+                        "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+                        "kty": "RSA",
+                        "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+                        "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+                        "e": "AQAB",
+                        "kid": "test",
+                        "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+                        "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+                        "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+                        "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+                    },
+                },
+            },
+        }
+    )
+    def test_key_returned_for_msc3861_client_secret_post(self) -> None:
+        """Test that the JWKS includes public part of JWK for private_key_jwt auth is used."""
+        channel = self.make_request("GET", "/_synapse/jwks")
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(
+            {
+                "keys": [
+                    {
+                        "kty": "RSA",
+                        "e": "AQAB",
+                        "kid": "test",
+                        "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+                    }
+                ]
+            },
+            channel.json_body,
+        )
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 6d04911d67..278808abb5 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Dict
 
 from parameterized import parameterized
 
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
@@ -26,22 +28,27 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeSite, make_request
 from tests.test_utils import SMALL_PNG
 
 VALID_TIMESTAMP = 1609459200000  # 2021-01-01 in milliseconds
 INVALID_TIMESTAMP_IN_S = 1893456000  # 2030-01-01 in seconds
 
 
-class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
+class _AdminMediaTests(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
         login.register_servlets,
     ]
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
+
+class DeleteMediaByIDTestCase(_AdminMediaTests):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -117,12 +124,8 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         Tests that delete a media is successfully
         """
 
-        download_resource = self.media_repo.children[b"download"]
-        upload_resource = self.media_repo.children[b"upload"]
-
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource,
             SMALL_PNG,
             tok=self.admin_user_tok,
             expect_code=200,
@@ -134,11 +137,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertEqual(server_name, self.server_name)
 
         # Attempt to access media
-        channel = make_request(
-            self.reactor,
-            FakeSite(download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id,
+            f"/_matrix/media/v3/download/{server_and_media_id}",
             shorthand=False,
             access_token=self.admin_user_tok,
         )
@@ -173,11 +174,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         )
 
         # Attempt to access media
-        channel = make_request(
-            self.reactor,
-            FakeSite(download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id,
+            f"/_matrix/media/v3/download/{server_and_media_id}",
             shorthand=False,
             access_token=self.admin_user_tok,
         )
@@ -194,7 +193,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(os.path.exists(local_path))
 
 
-class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
+class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
@@ -529,11 +528,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         Create a media and return media_id and server_and_media_id
         """
-        upload_resource = self.media_repo.children[b"upload"]
-
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource,
             SMALL_PNG,
             tok=self.admin_user_tok,
             expect_code=200,
@@ -553,16 +549,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         Try to access a media and check the result
         """
-        download_resource = self.media_repo.children[b"download"]
-
         media_id = server_and_media_id.split("/")[1]
         local_path = self.filepaths.local_media_filepath(media_id)
 
-        channel = make_request(
-            self.reactor,
-            FakeSite(download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id,
+            f"/_matrix/media/v3/download/{server_and_media_id}",
             shorthand=False,
             access_token=self.admin_user_tok,
         )
@@ -591,27 +583,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             self.assertFalse(os.path.exists(local_path))
 
 
-class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets,
-        synapse.rest.admin.register_servlets_for_media_repo,
-        login.register_servlets,
-    ]
-
+class QuarantineMediaByIDTestCase(_AdminMediaTests):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        media_repo = hs.get_media_repository_resource()
         self.store = hs.get_datastores().main
         self.server_name = hs.hostname
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        # Create media
-        upload_resource = media_repo.children[b"upload"]
-
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource,
             SMALL_PNG,
             tok=self.admin_user_tok,
             expect_code=200,
@@ -720,26 +701,16 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(media_info["quarantined_by"])
 
 
-class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets,
-        synapse.rest.admin.register_servlets_for_media_repo,
-        login.register_servlets,
-    ]
-
+class ProtectMediaByIDTestCase(_AdminMediaTests):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        media_repo = hs.get_media_repository_resource()
+        hs.get_media_repository_resource()
         self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        # Create media
-        upload_resource = media_repo.children[b"upload"]
-
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource,
             SMALL_PNG,
             tok=self.admin_user_tok,
             expect_code=200,
@@ -816,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(media_info["safe_from_quarantine"])
 
 
-class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
+class PurgeMediaCacheTestCase(_AdminMediaTests):
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index eb50086c50..6ed451d7c4 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -15,26 +15,34 @@ import json
 import time
 import urllib.parse
 from typing import List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
+from twisted.internet.task import deferLater
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.constants import EventTypes, Membership, RoomTypes
 from synapse.api.errors import Codes
-from synapse.handlers.pagination import PaginationHandler, PurgeStatus
+from synapse.handlers.pagination import (
+    PURGE_ROOM_ACTION_NAME,
+    SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
+)
 from synapse.rest.client import directory, events, login, room
 from synapse.server import HomeServer
+from synapse.types import UserID
 from synapse.util import Clock
-from synapse.util.stringutils import random_string
+from synapse.util.task_scheduler import TaskScheduler
 
 from tests import unittest
 
 """Tests admin REST events for /rooms paths."""
 
 
+ONE_HOUR_IN_S = 3600
+
+
 class DeleteRoomTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -46,6 +54,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.task_scheduler = hs.get_task_scheduler()
         hs.config.consent.user_consent_version = "1"
 
         consent_uri_builder = Mock()
@@ -476,6 +485,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.task_scheduler = hs.get_task_scheduler()
         hs.config.consent.user_consent_version = "1"
 
         consent_uri_builder = Mock()
@@ -502,6 +512,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         )
         self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
 
+        self.room_member_handler = hs.get_room_member_handler()
+        self.pagination_handler = hs.get_pagination_handler()
+
     @parameterized.expand(
         [
             ("DELETE", "/_synapse/admin/v2/rooms/%s"),
@@ -661,7 +674,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         delete_id1 = channel.json_body["delete_id"]
 
         # go ahead
-        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+        self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
 
         # second task
         channel = self.make_request(
@@ -686,12 +699,14 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(2, len(channel.json_body["results"]))
         self.assertEqual("complete", channel.json_body["results"][0]["status"])
         self.assertEqual("complete", channel.json_body["results"][1]["status"])
-        self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
-        self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+        delete_ids = {delete_id1, delete_id2}
+        self.assertTrue(channel.json_body["results"][0]["delete_id"] in delete_ids)
+        delete_ids.remove(channel.json_body["results"][0]["delete_id"])
+        self.assertTrue(channel.json_body["results"][1]["delete_id"] in delete_ids)
 
         # get status after more than clearing time for first task
         # second task is not cleared
-        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+        self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
 
         channel = self.make_request(
             "GET",
@@ -705,7 +720,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
 
         # get status after more than clearing time for all tasks
-        self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+        self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
 
         channel = self.make_request(
             "GET",
@@ -721,6 +736,13 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
 
         body = {"new_room_user_id": self.admin_user}
 
+        # Mock PaginationHandler.purge_room to sleep for 100s, so we have time to do a second call
+        # before the purge is over. Note that it doesn't purge anymore, but we don't care.
+        async def purge_room(room_id: str, force: bool) -> None:
+            await deferLater(self.hs.get_reactor(), 100, lambda: None)
+
+        self.pagination_handler.purge_room = AsyncMock(side_effect=purge_room)  # type: ignore[method-assign]
+
         # first call to delete room
         # and do not wait for finish the task
         first_channel = self.make_request(
@@ -728,7 +750,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             self.url.encode("ascii"),
             content=body,
             access_token=self.admin_user_tok,
-            await_result=False,
         )
 
         # second call to delete room
@@ -742,7 +763,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
         self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
         self.assertEqual(
-            f"History purge already in progress for {self.room_id}",
+            f"Purge already in progress for {self.room_id}",
             second_channel.json_body["error"],
         )
 
@@ -751,6 +772,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
         self.assertIn("delete_id", first_channel.json_body)
 
+        # wait for purge_room to finish
+        self.pump(1)
+
         # check status after finish the task
         self._test_result(
             first_channel.json_body["delete_id"],
@@ -972,6 +996,115 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         # Assert we can no longer peek into the room
         self._assert_peek(self.room_id, expect_code=403)
 
+    @unittest.override_config({"forgotten_room_retention_period": "1d"})
+    def test_purge_forgotten_room(self) -> None:
+        # Create a test room
+        room_id = self.helper.create_room_as(
+            self.admin_user,
+            tok=self.admin_user_tok,
+        )
+
+        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+        self.get_success(
+            self.room_member_handler.forget(
+                UserID.from_string(self.admin_user), room_id
+            )
+        )
+
+        # Test that room is not yet purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(room_id)
+
+        # Advance 24 hours in the future, past the `forgotten_room_retention_period`
+        self.reactor.advance(24 * ONE_HOUR_IN_S)
+
+        self._is_purged(room_id)
+
+    def test_scheduled_purge_room(self) -> None:
+        # Create a test room
+        room_id = self.helper.create_room_as(
+            self.admin_user,
+            tok=self.admin_user_tok,
+        )
+        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+        # Schedule a purge 10 seconds in the future
+        self.get_success(
+            self.task_scheduler.schedule_task(
+                PURGE_ROOM_ACTION_NAME,
+                resource_id=room_id,
+                timestamp=self.clock.time_msec() + 10 * 1000,
+            )
+        )
+
+        # Test that room is not yet purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(room_id)
+
+        # Wait for next scheduler run
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS)
+
+        self._is_purged(room_id)
+
+    def test_schedule_shutdown_room(self) -> None:
+        # Create a test room
+        room_id = self.helper.create_room_as(
+            self.other_user,
+            tok=self.other_user_tok,
+        )
+
+        # Schedule a shutdown 10 seconds in the future
+        delete_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
+                resource_id=room_id,
+                params={
+                    "requester_user_id": self.admin_user,
+                    "new_room_user_id": self.admin_user,
+                    "new_room_name": None,
+                    "message": None,
+                    "block": False,
+                    "purge": True,
+                    "force_purge": True,
+                },
+                timestamp=self.clock.time_msec() + 10 * 1000,
+            )
+        )
+
+        # Test that room is not yet shutdown
+        self._is_member(room_id, self.other_user)
+
+        # Test that room is not yet purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(room_id)
+
+        # Wait for next scheduler run
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS)
+
+        # Test that all users has been kicked (room is shutdown)
+        self._has_no_members(room_id)
+
+        self._is_purged(room_id)
+
+        # Retrieve delete results
+        result = self.make_request(
+            "GET",
+            self.url_status_by_delete_id + delete_id,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, result.code, msg=result.json_body)
+
+        # Check that the user is in kicked_users
+        self.assertIn(
+            self.other_user, result.json_body["shutdown_room"]["kicked_users"]
+        )
+
+        new_room_id = result.json_body["shutdown_room"]["new_room_id"]
+        self.assertTrue(new_room_id)
+
+        # Check that the user is actually in the new room
+        self._is_member(new_room_id, self.other_user)
+
     def _is_blocked(self, room_id: str, expect: bool = True) -> None:
         """Assert that the room is blocked or not"""
         d = self.store.is_room_blocked(room_id)
@@ -1034,7 +1167,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             kicked_user: a user_id which is kicked from the room
             expect_new_room: if we expect that a new room was created
         """
-
         # get information by room_id
         channel_room_id = self.make_request(
             "GET",
@@ -1957,11 +2089,8 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
 
         # Purge every event before the second event.
-        purge_id = random_string(16)
-        pagination_handler._purges_by_id[purge_id] = PurgeStatus()
         self.get_success(
-            pagination_handler._purge_history(
-                purge_id=purge_id,
+            pagination_handler.purge_history(
                 room_id=self.room_id,
                 token=second_token_str,
                 delete_local_events=True,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 28b999573e..dfd14f5751 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -22,6 +22,7 @@ from synapse.server import HomeServer
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict
 from synapse.util import Clock
+from synapse.util.stringutils import random_string
 
 from tests import unittest
 from tests.unittest import override_config
@@ -413,11 +414,24 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(messages[0]["content"]["body"], "test msg one")
         self.assertEqual(messages[0]["sender"], "@notices:test")
 
+        random_string(16)
+
         # shut down and purge room
         self.get_success(
-            self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
-        )
-        self.get_success(self.pagination_handler.purge_room(first_room_id))
+            self.room_shutdown_handler.shutdown_room(
+                first_room_id,
+                {
+                    "requester_user_id": self.admin_user,
+                    "new_room_user_id": None,
+                    "new_room_name": None,
+                    "message": None,
+                    "block": False,
+                    "purge": True,
+                    "force_purge": False,
+                },
+            )
+        )
+        self.get_success(self.pagination_handler.purge_room(first_room_id, force=False))
 
         # user is not member anymore
         self._check_invite_and_join_status(self.other_user, 0, 0)
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index b60f16b914..cd8ee274d8 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
@@ -34,8 +35,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.media_repo = hs.get_media_repository_resource()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -44,6 +43,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         self.url = "/_synapse/admin/v1/statistics/users/media"
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     def test_no_auth(self) -> None:
         """
         Try to list users without authentication.
@@ -470,12 +474,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
             user_token: Access token of the user
             number_media: Number of media to be created for the user
         """
-        upload_resource = self.media_repo.children[b"upload"]
         for _ in range(number_media):
             # Upload some media into the room
-            self.helper.upload_media(
-                upload_resource, SMALL_PNG, tok=user_token, expect_code=200
-            )
+            self.helper.upload_media(SMALL_PNG, tok=user_token, expect_code=200)
 
     def _check_fields(self, content: List[JsonDict]) -> None:
         """Checks that all attributes are present in content
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 434bb56d44..37f37a09d8 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,26 +17,36 @@ import hmac
 import os
 import urllib.parse
 from binascii import unhexlify
-from typing import List, Optional
-from unittest.mock import Mock, patch
+from typing import Dict, List, Optional
+from unittest.mock import AsyncMock, Mock, patch
 
 from parameterized import parameterized, parameterized_class
 
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.media.filepath import MediaFilePaths
-from synapse.rest.client import devices, login, logout, profile, register, room, sync
+from synapse.rest.client import (
+    devices,
+    login,
+    logout,
+    profile,
+    register,
+    room,
+    sync,
+    user_directory,
+)
 from synapse.server import HomeServer
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeSite, make_request
-from tests.test_utils import SMALL_PNG, make_awaitable
+from tests.test_utils import SMALL_PNG
 from tests.unittest import override_config
 
 
@@ -62,8 +72,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         self.hs.config.registration.registration_shared_secret = "shared"
 
-        self.hs.get_media_repository = Mock()  # type: ignore[assignment]
-        self.hs.get_deactivate_account_handler = Mock()  # type: ignore[assignment]
+        self.hs.get_media_repository = Mock()  # type: ignore[method-assign]
+        self.hs.get_deactivate_account_handler = Mock()  # type: ignore[method-assign]
 
         return self.hs
 
@@ -410,8 +420,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastores().main
 
         # Set monthly active users to the limit
-        store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -447,6 +457,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
+        room.register_servlets,
     ]
     url = "/_synapse/admin/v2/users"
 
@@ -497,6 +508,62 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         # Check that all fields are available
         self._check_fields(channel.json_body["users"])
 
+    def test_last_seen(self) -> None:
+        """
+        Test that last_seen_ts field is properly working.
+        """
+        user1 = self.register_user("u1", "pass")
+        user1_token = self.login("u1", "pass")
+        user2 = self.register_user("u2", "pass")
+        user2_token = self.login("u2", "pass")
+        user3 = self.register_user("u3", "pass")
+        user3_token = self.login("u3", "pass")
+
+        self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user2, tok=user2_token)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user1, tok=user1_token)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user3, tok=user3_token)
+        self.reactor.advance(10)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(4, len(channel.json_body["users"]))
+        self.assertEqual(4, channel.json_body["total"])
+
+        admin_last_seen = channel.json_body["users"][0]["last_seen_ts"]
+        user1_last_seen = channel.json_body["users"][1]["last_seen_ts"]
+        user2_last_seen = channel.json_body["users"][2]["last_seen_ts"]
+        user3_last_seen = channel.json_body["users"][3]["last_seen_ts"]
+        self.assertTrue(admin_last_seen > 0 and admin_last_seen < 10000)
+        self.assertTrue(user2_last_seen > 10000 and user2_last_seen < 20000)
+        self.assertTrue(user1_last_seen > 20000 and user1_last_seen < 30000)
+        self.assertTrue(user3_last_seen > 30000 and user3_last_seen < 40000)
+
+        self._order_test([self.admin_user, user2, user1, user3], "last_seen_ts")
+
+        self.reactor.advance(LAST_SEEN_GRANULARITY / 1000)
+        self.helper.create_room_as(user1, tok=user1_token)
+        self.reactor.advance(10)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/" + user1,
+            access_token=self.admin_user_tok,
+        )
+        self.assertTrue(
+            channel.json_body["last_seen_ts"] > 40000 + LAST_SEEN_GRANULARITY
+        )
+
+        self._order_test([self.admin_user, user2, user3, user1], "last_seen_ts")
+
     def test_search_term(self) -> None:
         """Test that searching for a users works correctly"""
 
@@ -870,6 +937,44 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
         self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
 
+    def test_filter_admins(self) -> None:
+        """
+        Tests whether the various values of the query parameter `admins` lead to the
+        expected result set.
+        """
+
+        # Register an additional non admin user
+        self.register_user("user", "pass", admin=False)
+
+        # Query all users
+        channel = self.make_request(
+            "GET",
+            f"{self.url}",
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(2, channel.json_body["total"])
+
+        # Query only admin users
+        channel = self.make_request(
+            "GET",
+            f"{self.url}?admins=true",
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertEqual(1, channel.json_body["users"][0]["admin"])
+
+        # Query only non admin users
+        channel = self.make_request(
+            "GET",
+            f"{self.url}?admins=false",
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertFalse(channel.json_body["users"][0]["admin"])
+
     @override_config(
         {
             "experimental_features": {
@@ -933,6 +1038,84 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
         self.assertEqual(not_approved_user, non_admin_user_ids[0])
 
+    def test_filter_not_user_types(self) -> None:
+        """Tests that the endpoint handles the not_user_types param"""
+
+        regular_user_id = self.register_user("normalo", "secret")
+
+        bot_user_id = self.register_user("robo", "secret")
+        self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/" + urllib.parse.quote(bot_user_id),
+            {"user_type": UserTypes.BOT},
+            access_token=self.admin_user_tok,
+        )
+
+        support_user_id = self.register_user("foo", "secret")
+        self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/" + urllib.parse.quote(support_user_id),
+            {"user_type": UserTypes.SUPPORT},
+            access_token=self.admin_user_tok,
+        )
+
+        def test_user_type(
+            expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
+        ) -> None:
+            """Runs a test for the not_user_types param
+            Args:
+                expected_user_ids: Ids of the users that are expected to be returned
+                not_user_types: List of values for the not_user_types param
+            """
+
+            user_type_query = ""
+
+            if not_user_types is not None:
+                user_type_query = "&".join(
+                    [f"not_user_type={u}" for u in not_user_types]
+                )
+
+            test_url = f"{self.url}?{user_type_query}"
+            channel = self.make_request(
+                "GET",
+                test_url,
+                access_token=self.admin_user_tok,
+            )
+
+            self.assertEqual(200, channel.code)
+            self.assertEqual(channel.json_body["total"], len(expected_user_ids))
+            self.assertEqual(
+                expected_user_ids,
+                [u["name"] for u in channel.json_body["users"]],
+            )
+
+        # Request without user_types →  all users expected
+        test_user_type([self.admin_user, support_user_id, regular_user_id, bot_user_id])
+
+        # Request and exclude bot users
+        test_user_type(
+            [self.admin_user, support_user_id, regular_user_id],
+            not_user_types=[UserTypes.BOT],
+        )
+
+        # Request and exclude bot and support users
+        test_user_type(
+            [self.admin_user, regular_user_id],
+            not_user_types=[UserTypes.BOT, UserTypes.SUPPORT],
+        )
+
+        # Request and exclude empty user types →  only expected the bot and support user
+        test_user_type([support_user_id, bot_user_id], not_user_types=[""])
+
+        # Request and exclude empty user types and bots →  only expected the support user
+        test_user_type([support_user_id], not_user_types=["", UserTypes.BOT])
+
+        # Request and exclude a custom type (neither service nor bot) →  expect all users
+        test_user_type(
+            [self.admin_user, support_user_id, regular_user_id, bot_user_id],
+            not_user_types=["custom"],
+        )
+
     def test_erasure_status(self) -> None:
         # Create a new user.
         user_id = self.register_user("eraseme", "eraseme")
@@ -963,6 +1146,32 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         users = {user["name"]: user for user in channel.json_body["users"]}
         self.assertIs(users[user_id]["erased"], True)
 
+    def test_filter_locked(self) -> None:
+        # Create a new user.
+        user_id = self.register_user("lockme", "lockme")
+
+        # Lock them
+        self.get_success(self.store.set_user_locked_status(user_id, True))
+
+        # Locked user should appear in list users API
+        channel = self.make_request(
+            "GET",
+            self.url + "?locked=true",
+            access_token=self.admin_user_tok,
+        )
+        users = {user["name"]: user for user in channel.json_body["users"]}
+        self.assertIn(user_id, users)
+        self.assertTrue(users[user_id]["locked"])
+
+        # Locked user should not appear in list users API
+        channel = self.make_request(
+            "GET",
+            self.url + "?locked=false",
+            access_token=self.admin_user_tok,
+        )
+        users = {user["name"]: user for user in channel.json_body["users"]}
+        self.assertNotIn(user_id, users)
+
     def _order_test(
         self,
         expected_user_list: List[str],
@@ -1010,6 +1219,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
             self.assertIn("displayname", u)
             self.assertIn("avatar_url", u)
             self.assertIn("creation_ts", u)
+            self.assertIn("last_seen_ts", u)
 
     def _create_users(self, number_users: int) -> None:
         """
@@ -1340,7 +1550,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         # To test deactivation for users without a profile, we delete the profile information for our user.
         self.get_success(
             self.store.db_pool.simple_delete_one(
-                table="profiles", keyvalues={"user_id": "user"}
+                table="profiles", keyvalues={"full_user_id": "@user:test"}
             )
         )
 
@@ -1399,6 +1609,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         sync.register_servlets,
         register.register_servlets,
+        user_directory.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -1708,8 +1919,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             )
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -1745,8 +1956,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         handler = self.hs.get_registration_handler()
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -2386,6 +2597,105 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # This key was removed intentionally. Ensure it is not accidentally re-included.
         self.assertNotIn("password_hash", channel.json_body)
 
+    def test_locked_user(self) -> None:
+        # User can sync
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/sync",
+            access_token=self.other_user_token,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is not authorized to sync anymore
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/sync",
+            access_token=self.other_user_token,
+        )
+        self.assertEqual(401, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"])
+        self.assertTrue(channel.json_body["soft_logout"])
+
+    @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+    def test_locked_user_not_in_user_dir(self) -> None:
+        # User is available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is not available anymore in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(0, len(channel.json_body["results"]))
+
+    @override_config(
+        {
+            "user_directory": {
+                "enabled": True,
+                "search_all_users": True,
+                "show_locked_users": True,
+            }
+        }
+    )
+    def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None:
+        # User is available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
+        # Lock user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"locked": True},
+        )
+
+        # User is still available in the user dir
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/user_directory/search",
+            {"search_term": self.other_user},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("results", channel.json_body)
+        self.assertEqual(1, len(channel.json_body["results"]))
+
     @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
     def test_change_name_deactivate_user_user_directory(self) -> None:
         """
@@ -2394,7 +2704,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # is in user directory
-        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        profile = self.get_success(self.store._get_user_in_directory(self.other_user))
         assert profile is not None
         self.assertTrue(profile["display_name"] == "User")
 
@@ -2411,7 +2721,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertTrue(channel.json_body["deactivated"])
 
         # is not in user directory
-        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        profile = self.get_success(self.store._get_user_in_directory(self.other_user))
         self.assertIsNone(profile)
 
         # Set new displayname user
@@ -2428,7 +2738,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Foobar", channel.json_body["displayname"])
 
         # is not in user directory
-        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        profile = self.get_success(self.store._get_user_in_directory(self.other_user))
         self.assertIsNone(profile)
 
     def test_reactivate_user(self) -> None:
@@ -2810,6 +3120,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertIn("consent_version", content)
         self.assertIn("consent_ts", content)
         self.assertIn("external_ids", content)
+        self.assertIn("last_seen_ts", content)
 
         # This key was removed intentionally. Ensure it is not accidentally re-included.
         self.assertNotIn("password_hash", content)
@@ -3110,7 +3421,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
-        self.media_repo = hs.get_media_repository_resource()
         self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -3121,6 +3431,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     @parameterized.expand(["GET", "DELETE"])
     def test_no_auth(self, method: str) -> None:
         """Try to list media of an user without authentication."""
@@ -3596,12 +3911,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Returns:
             The ID of the newly created media.
         """
-        upload_resource = self.media_repo.children[b"upload"]
-        download_resource = self.media_repo.children[b"download"]
-
         # Upload some media into the room
         response = self.helper.upload_media(
-            upload_resource, image_data, user_token, filename, expect_code=200
+            image_data, user_token, filename, expect_code=200
         )
 
         # Extract media ID from the response
@@ -3609,11 +3921,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         media_id = server_and_media_id.split("/")[1]
 
         # Try to access a media and to create `last_access_ts`
-        channel = make_request(
-            self.reactor,
-            FakeSite(download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            server_and_media_id,
+            f"/_matrix/media/v3/download/{server_and_media_id}",
             shorthand=False,
             access_token=user_token,
         )
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 6c04e6c56c..4c69d224b8 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -50,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
             )
 
         handler = self.hs.get_registration_handler()
-        handler.check_username = check_username  # type: ignore[assignment]
+        handler.check_username = check_username  # type: ignore[method-assign]
 
     def test_username_available(self) -> None:
         """
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index ac19f3c6da..cffbda9a7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 from synapse.server import HomeServer
+from synapse.storage._base import db_to_json
 from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
@@ -134,6 +135,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+        # Check that the UI Auth information doesn't store the password in the database.
+        #
+        # Note that we don't have the UI Auth session ID, so just pull out the single
+        # row.
+        ui_auth_data = self.get_success(
+            self.store.db_pool.simple_select_one(
+                "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+            )
+        )
+        client_dict = db_to_json(ui_auth_data["clientdict"])
+        self.assertNotIn("new_password", client_dict)
+
     @override_config({"rc_3pid_validation": {"burst_count": 3}})
     def test_ratelimit_by_email(self) -> None:
         """Test that we ratelimit /requestToken for the same email."""
@@ -562,7 +575,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
         # create a bunch of users and add keys for them
         users = []
-        for i in range(0, 20):
+        for i in range(20):
             user_id = self.register_user("missPiggy" + str(i), "test")
             users.append((user_id,))
 
@@ -1346,7 +1359,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
                 return {}
 
         # Register a mock that will return the expected result depending on the remote.
-        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)  # type: ignore[assignment]
+        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)  # type: ignore[method-assign]
 
         # Check that we've got the correct response from the client-side endpoint.
         self._test_status(
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index d5b0640e7a..481db9a687 100644
--- a/tests/rest/client/test_account_data.py
+++ b/tests/rest/client/test_account_data.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.rest import admin
 from synapse.rest.client import account_data, login, room
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class AccountDataTestCase(unittest.HomeserverTestCase):
@@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
         """Tests that the on_account_data_updated module callback is called correctly when
         a user's account data changes.
         """
-        mocked_callback = Mock(return_value=make_awaitable(None))
+        mocked_callback = AsyncMock(return_value=None)
         self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
             mocked_callback
         )
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..cf23430f6a 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             self.assertGreater(len(details["support"]), 0)
             for room_version in details["support"]:
                 self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+    def test_get_get_token_login_fields_when_disabled(self) -> None:
+        """By default login via an existing session is disabled."""
+        access_token = self.get_success(
+            self.auth_handler.create_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, HTTPStatus.OK)
+        self.assertFalse(capabilities["m.get_login_token"]["enabled"])
+
+    @override_config({"login_via_existing_session": {"enabled": True}})
+    def test_get_get_token_login_fields_when_enabled(self) -> None:
+        access_token = self.get_success(
+            self.auth_handler.create_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, HTTPStatus.OK)
+        self.assertTrue(capabilities["m.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index d80eea17d3..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 from http import HTTPStatus
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, login, register
+from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         register.register_servlets,
         devices.register_servlets,
+        keys.register_servlets,
     ]
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.registration = hs.get_registration_handler()
+        self.message_handler = hs.get_device_message_handler()
+
     def test_PUT(self) -> None:
         """Sanity-check that we can PUT a dehydrated device.
 
@@ -226,7 +233,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                 "device_data": {
                     "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
                     "account": "dehydrated_device",
-                }
+                },
+                "device_keys": {
+                    "user_id": "@alice:test",
+                    "device_id": "device1",
+                    "valid_until_ts": "80",
+                    "algorithms": [
+                        "m.olm.curve25519-aes-sha2",
+                    ],
+                    "keys": {
+                        "<algorithm>:<device_id>": "<key_base64>",
+                    },
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                    },
+                },
             },
             access_token=token,
             shorthand=False,
@@ -234,3 +255,336 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         device_id = channel.json_body.get("device_id")
         self.assertIsInstance(device_id, str)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_msc3814(self) -> None:
+        user = self.register_user("mikey", "pass")
+        token = self.login(user, "pass", device_id="device1")
+        content: JsonDict = {
+            "device_data": {
+                "algorithm": "m.dehydration.v1.olm",
+            },
+            "device_id": "device1",
+            "initial_device_display_name": "foo bar",
+            "device_keys": {
+                "user_id": "@mikey:test",
+                "device_id": "device1",
+                "valid_until_ts": "80",
+                "algorithms": [
+                    "m.olm.curve25519-aes-sha2",
+                ],
+                "keys": {
+                    "<algorithm>:<device_id>": "<key_base64>",
+                },
+                "signatures": {
+                    "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                },
+            },
+            "fallback_keys": {
+                "alg1:device1": "f4llb4ckk3y",
+                "signed_<algorithm>:<device_id>": {
+                    "fallback": "true",
+                    "key": "f4llb4ckk3y",
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+                    },
+                },
+            },
+            "one_time_keys": {"alg1:k1": "0net1m3k3y"},
+        }
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device1", device_id)
+
+        # test that we can now GET the dehydrated device info
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        returned_device_id = channel.json_body.get("device_id")
+        self.assertEqual(returned_device_id, device_id)
+        device_data = channel.json_body.get("device_data")
+        expected_device_data = {
+            "algorithm": "m.dehydration.v1.olm",
+        }
+        self.assertEqual(device_data, expected_device_data)
+
+        # test that the keys are correctly uploaded
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device1"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["device_keys"][user][device_id]["keys"],
+            content["device_keys"]["keys"],
+        )
+        # first claim should return the onetime key we uploaded
+        res = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+            },
+        )
+        # second claim should return fallback key
+        res2 = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res2,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+            },
+        )
+
+        # create another device for the user
+        (
+            new_device_id,
+            _,
+            _,
+            _,
+        ) = self.get_success(
+            self.registration.register_device(
+                user_id=user,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+        requester = create_requester(user, device_id=new_device_id)
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user: {device_id: {"body": "test_message"}}},
+            )
+        )
+        self.pump()
+
+        # make sure we can fetch the message with our dehydrated device id
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        expected_content = {"body": "test_message"}
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+        # fetch messages again and make sure that the message was not deleted
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+        next_batch_token = channel.json_body.get("next_batch")
+
+        # make sure fetching messages with next batch token works - there are no unfetched
+        # messages so we should receive an empty array
+        content = {"next_batch": next_batch_token}
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"], [])
+
+        # make sure we can delete the dehydrated device
+        channel = self.make_request(
+            "DELETE",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # ...and after deleting it is no longer available
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 401)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_msc3814_dehydrated_device_delete_works(self) -> None:
+        user = self.register_user("mikey", "pass")
+        token = self.login(user, "pass", device_id="device1")
+        content: JsonDict = {
+            "device_data": {
+                "algorithm": "m.dehydration.v1.olm",
+            },
+            "device_id": "device2",
+            "initial_device_display_name": "foo bar",
+            "device_keys": {
+                "user_id": "@mikey:test",
+                "device_id": "device2",
+                "valid_until_ts": "80",
+                "algorithms": [
+                    "m.olm.curve25519-aes-sha2",
+                ],
+                "keys": {
+                    "<algorithm>:<device_id>": "<key_base64>",
+                },
+                "signatures": {
+                    "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                },
+            },
+        }
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device2", device_id)
+
+        # ensure that keys were uploaded and available
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device2"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(
+            channel.json_body["device_keys"][user]["device2"]["keys"],
+            {
+                "<algorithm>:<device_id>": "<key_base64>",
+            },
+        )
+
+        # delete the dehydrated device
+        channel = self.make_request(
+            "DELETE",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # ensure that keys are no longer available for deleted device
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device2"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
+
+        # check that an old device is deleted when user PUTs a new device
+        # First, create a device
+        content["device_id"] = "device3"
+        content["device_keys"]["device_id"] = "device3"
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device3", device_id)
+
+        # create a second device without deleting first device
+        content["device_id"] = "device4"
+        content["device_keys"]["device_id"] = "device4"
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device4", device_id)
+
+        # check that the second device that was created is what is returned when we GET
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        returned_device_id = channel.json_body["device_id"]
+        self.assertEqual(returned_device_id, "device4")
+
+        # and that if we query the keys for the first device they are not there
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device3"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 54df2a252c..141e0f57a3 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -45,7 +45,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(config=config)
 
-        hs.get_federation_handler = Mock()  # type: ignore[assignment]
+        hs.get_federation_handler = Mock()  # type: ignore[method-assign]
 
         return hs
 
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 9faa9de050..90a8df147c 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.get_success(
-            self.store.get_user_filter(user_localpart="apple", filter_id=0)
+            self.store.get_user_filter(
+                user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0
+            )
         )
         self.pump()
         self.assertEqual(filter, self.EXAMPLE_FILTER)
@@ -63,14 +65,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     def test_add_filter_non_local_user(self) -> None:
         _is_mine = self.hs.is_mine
-        self.hs.is_mine = lambda target_user: False  # type: ignore[assignment]
+        self.hs.is_mine = lambda target_user: False  # type: ignore[method-assign]
         channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
 
-        self.hs.is_mine = _is_mine  # type: ignore[assignment]
+        self.hs.is_mine = _is_mine  # type: ignore[method-assign]
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index dc32982e22..768d7ad4c2 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import time
 import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
 from unittest.mock import Mock
 from urllib.parse import urlencode
 
 import pymacaroons
+from typing_extensions import Literal
 
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
+from synapse.module_api import ModuleApi
 from synapse.rest.client import devices, login, logout, register
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
 ]
 
 
+class TestSpamChecker:
+    def __init__(self, config: None, api: ModuleApi):
+        api.register_spam_checker_callbacks(
+            check_login_for_spam=self.check_login_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config: JsonDict) -> None:
+        return None
+
+    async def check_login_for_spam(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        request_info: Collection[Tuple[Optional[str], str]],
+        auth_provider_id: Optional[str] = None,
+    ) -> Union[
+        Literal["NOT_SPAM"],
+        Tuple["synapse.module_api.errors.Codes", JsonDict],
+    ]:
+        return "NOT_SPAM"
+
+
+class DenyAllSpamChecker:
+    def __init__(self, config: None, api: ModuleApi):
+        api.register_spam_checker_callbacks(
+            check_login_for_spam=self.check_login_for_spam,
+        )
+
+    @staticmethod
+    def parse_config(config: JsonDict) -> None:
+        return None
+
+    async def check_login_for_spam(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        request_info: Collection[Tuple[Optional[str], str]],
+        auth_provider_id: Optional[str] = None,
+    ) -> Union[
+        Literal["NOT_SPAM"],
+        Tuple["synapse.module_api.errors.Codes", JsonDict],
+    ]:
+        # Return an odd set of values to ensure that they get correctly passed
+        # to the client.
+        return Codes.LIMIT_EXCEEDED, {"extra": "value"}
+
+
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -117,16 +169,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # which sets these values to 10000, but as we're overriding the entire
                 # rc_login dict here, we need to set this manually as well
                 "account": {"per_second": 10000, "burst_count": 10000},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_address(self) -> None:
         # Create different users so we're sure not to be bothered by the per-user
         # ratelimiter.
-        for i in range(0, 6):
+        for i in range(6):
             self.register_user("kermit" + str(i), "monkey")
 
-        for i in range(0, 6):
+        for i in range(6):
             params = {
                 "type": "m.login.password",
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
@@ -137,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -165,13 +221,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # which sets these values to 10000, but as we're overriding the entire
                 # rc_login dict here, we need to set this manually as well
                 "address": {"per_second": 10000, "burst_count": 10000},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_account(self) -> None:
         self.register_user("kermit", "monkey")
 
-        for i in range(0, 6):
+        for i in range(6):
             params = {
                 "type": "m.login.password",
                 "identifier": {"type": "m.id.user", "user": "kermit"},
@@ -182,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0)
 
@@ -210,13 +270,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # rc_login dict here, we need to set this manually as well
                 "address": {"per_second": 10000, "burst_count": 10000},
                 "failed_attempts": {"per_second": 0.17, "burst_count": 5},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
         self.register_user("kermit", "monkey")
 
-        for i in range(0, 6):
+        for i in range(6):
             params = {
                 "type": "m.login.password",
                 "identifier": {"type": "m.id.user", "user": "kermit"},
@@ -227,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 403, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -446,6 +510,82 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
         )
 
+    def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
+        """GET /login should return m.login.token without get_login_token"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+        self.assertNotIn("m.login.token", flows)
+
+    @override_config({"login_via_existing_session": {"enabled": True}})
+    def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
+        """GET /login should return m.login.token with get_login_token true"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertCountEqual(
+            channel.json_body["flows"],
+            [
+                {"type": "m.login.token", "get_login_token": True},
+                {"type": "m.login.password"},
+                {"type": "m.login.application_service"},
+            ],
+        )
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": TestSpamChecker.__module__
+                    + "."
+                    + TestSpamChecker.__qualname__
+                }
+            ]
+        }
+    )
+    def test_spam_checker_allow(self) -> None:
+        """Check that that adding a spam checker doesn't break login."""
+        self.register_user("kermit", "monkey")
+
+        body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/login",
+            body,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": DenyAllSpamChecker.__module__
+                    + "."
+                    + DenyAllSpamChecker.__qualname__
+                }
+            ]
+        }
+    )
+    def test_spam_checker_deny(self) -> None:
+        """Check that login"""
+
+        self.register_user("kermit", "monkey")
+
+        body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/login",
+            body,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertLessEqual(
+            {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}.items(),
+            channel.json_body.items(),
+        )
+
 
 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
 class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..f05e619aa8 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -15,14 +15,14 @@
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
-from synapse.rest.client import login, login_token_request
+from synapse.rest.client import login, login_token_request, versions
 from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
 
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token"
 
 
 class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         admin.register_servlets,
         login_token_request.register_servlets,
+        versions.register_servlets,  # TODO: remove once unstable revision 0 support is removed
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.password = "password"
 
     def test_disabled(self) -> None:
-        channel = self.make_request("POST", endpoint, {}, access_token=None)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
         self.assertEqual(channel.code, 404)
 
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 404)
 
-    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    @override_config({"login_via_existing_session": {"enabled": True}})
     def test_require_auth(self) -> None:
-        channel = self.make_request("POST", endpoint, {}, access_token=None)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
         self.assertEqual(channel.code, 401)
 
-    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    @override_config({"login_via_existing_session": {"enabled": True}})
     def test_uia_on(self) -> None:
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 401)
         self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
 
@@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
             },
         }
 
-        channel = self.make_request("POST", endpoint, uia, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token)
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["expires_in"], 300)
+        self.assertEqual(channel.json_body["expires_in_ms"], 300000)
 
         login_token = channel.json_body["login_token"]
 
@@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["user_id"], user_id)
 
     @override_config(
-        {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+        {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}}
     )
     def test_uia_off(self) -> None:
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["expires_in"], 300)
+        self.assertEqual(channel.json_body["expires_in_ms"], 300000)
 
         login_token = channel.json_body["login_token"]
 
@@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            "experimental_features": {
-                "msc3882_enabled": True,
-                "msc3882_ui_auth": False,
-                "msc3882_token_timeout": "15s",
+            "login_via_existing_session": {
+                "enabled": True,
+                "require_ui_auth": False,
+                "token_timeout": "15s",
             }
         }
     )
@@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["expires_in_ms"], 15000)
+
+    @override_config(
+        {
+            "login_via_existing_session": {
+                "enabled": True,
+                "require_ui_auth": False,
+                "token_timeout": "15s",
+            }
+        }
+    )
+    def test_unstable_support(self) -> None:
+        # TODO: remove support for unstable MSC3882 is no longer needed
+
+        # check feature is advertised in versions response:
+        channel = self.make_request(
+            "GET", "/_matrix/client/versions", {}, access_token=None
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["unstable_features"]["org.matrix.msc3882"], True
+        )
+
+        self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        # check feature is available via the unstable endpoint and returns an expires_in value in seconds
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/org.matrix.msc3882/login/token",
+            {},
+            access_token=token,
+        )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
index 0b8fcb0c47..524ea6047e 100644
--- a/tests/rest/client/test_models.py
+++ b/tests/rest/client/test_models.py
@@ -12,12 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest as stdlib_unittest
+from typing import TYPE_CHECKING
 
-from pydantic import BaseModel, ValidationError
 from typing_extensions import Literal
 
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
 from synapse.rest.client.models import EmailRequestTokenBody
 
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+    from pydantic.v1 import BaseModel, ValidationError
+else:
+    from pydantic import BaseModel, ValidationError
+
 
 class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
     class Model(BaseModel):
diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
index 700f6587a0..41ceb3db51 100644
--- a/tests/rest/client/test_notifications.py
+++ b/tests/rest/client/test_notifications.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase
 
 
@@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
-        fed_transport_client.send_transaction = simple_async_mock({})
+        fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=fed_transport_client,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index dcbb125a3b..66b387cea3 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -23,7 +23,6 @@ from synapse.types import UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class PresenceTestCase(unittest.HomeserverTestCase):
@@ -36,11 +35,10 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.presence_handler = Mock(spec=PresenceHandler)
-        self.presence_handler.set_state.return_value = make_awaitable(None)
+        self.presence_handler.set_state = AsyncMock(return_value=None)
 
         hs = self.setup_test_homeserver(
             "red",
-            federation_http_client=None,
             federation_client=Mock(),
             presence_handler=self.presence_handler,
         )
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 27c93ad761..ecae092b47 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -68,6 +68,18 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_displayname()
         self.assertEqual(res, "test")
 
+    def test_set_displayname_with_extra_spaces(self) -> None:
+        channel = self.make_request(
+            "PUT",
+            "/profile/%s/displayname" % (self.owner,),
+            content={"displayname": "  test  "},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        res = self._get_displayname()
+        self.assertEqual(res, "test")
+
     def test_set_displayname_noauth(self) -> None:
         channel = self.make_request(
             "PUT",
diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py
index ec5b19c33f..3de43760db 100644
--- a/tests/rest/client/test_public_rooms.py
+++ b/tests/rest/client/test_public_rooms.py
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
 
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin, login, room
 from synapse.server import HomeServer
-from synapse.types import PublicRoom, ThirdPartyInstanceID
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -87,6 +85,7 @@ class PublicRoomsTestCase(HomeserverTestCase):
     def test_pagination_limit_1(self) -> None:
         returned_rooms = set()
 
+        channel = None
         for i in range(10):
             next_batch = None if i == 0 else channel.json_body["next_batch"]
             since_query_str = f"&since={next_batch}" if next_batch else ""
@@ -115,6 +114,7 @@ class PublicRoomsTestCase(HomeserverTestCase):
     def test_pagination_limit_2(self) -> None:
         returned_rooms = set()
 
+        channel = None
         for i in range(5):
             next_batch = None if i == 0 else channel.json_body["next_batch"]
             since_query_str = f"&since={next_batch}" if next_batch else ""
diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index 4f875b9289..5aca74475f 100644
--- a/tests/rest/client/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
@@ -412,3 +412,70 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         )
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_contains_user_name(self) -> None:
+        """
+        Tests that `contains_user_name` rule is present and have proper value in `pattern`.
+        """
+        username = "bob"
+        self.register_user(username, "pass")
+        token = self.login(username, "pass")
+
+        channel = self.make_request(
+            "GET",
+            "/pushrules/global/content/.m.rule.contains_user_name",
+            access_token=token,
+        )
+
+        self.assertEqual(channel.code, 200)
+
+        self.assertEqual(
+            {
+                "rule_id": ".m.rule.contains_user_name",
+                "default": True,
+                "enabled": True,
+                "pattern": username,
+                "actions": [
+                    "notify",
+                    {"set_tweak": "highlight"},
+                    {"set_tweak": "sound", "value": "default"},
+                ],
+            },
+            channel.json_body,
+        )
+
+    def test_is_user_mention(self) -> None:
+        """
+        Tests that `is_user_mention` rule is present and have proper value in `value`.
+        """
+        user = self.register_user("bob", "pass")
+        token = self.login("bob", "pass")
+
+        channel = self.make_request(
+            "GET",
+            "/pushrules/global/override/.m.rule.is_user_mention",
+            access_token=token,
+        )
+
+        self.assertEqual(channel.code, 200)
+
+        self.assertEqual(
+            {
+                "rule_id": ".m.rule.is_user_mention",
+                "default": True,
+                "enabled": True,
+                "conditions": [
+                    {
+                        "kind": "event_property_contains",
+                        "key": "content.m\\.mentions.user_ids",
+                        "value": user,
+                    }
+                ],
+                "actions": [
+                    "notify",
+                    {"set_tweak": "highlight"},
+                    {"set_tweak": "sound", "value": "default"},
+                ],
+            },
+            channel.json_body,
+        )
diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py
index 0eedcdb476..5cdd5694a0 100644
--- a/tests/rest/client/test_read_marker.py
+++ b/tests/rest/client/test_read_marker.py
@@ -131,9 +131,6 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
         event = self.get_success(self.store.get_event(event_id_1, allow_none=True))
         assert event is None
 
-        # TODO See https://github.com/matrix-org/synapse/issues/13476
-        self.store.get_event_ordering.invalidate_all()
-
         # Test moving the read marker to a newer event
         event_id_2 = send_message()
         channel = self.make_request(
diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
index 2a7fcea386..ec638c89b7 100644
--- a/tests/rest/client/test_receipts.py
+++ b/tests/rest/client/test_receipts.py
@@ -11,11 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+from typing import Optional
+
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.rest.client import login, receipts, register
+from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, ReceiptTypes
+from synapse.rest.client import login, receipts, room, sync
 from synapse.server import HomeServer
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -24,30 +29,113 @@ from tests import unittest
 class ReceiptsTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
-        register.register_servlets,
         receipts.register_servlets,
         synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.owner = self.register_user("owner", "pass")
-        self.owner_tok = self.login("owner", "pass")
+        self.url = "/sync?since=%s"
+        self.next_batch = "s0"
+
+        # Register the first user
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        # Create the room
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # Register the second user
+        self.user2 = self.register_user("kermit2", "monkey")
+        self.tok2 = self.login("kermit2", "monkey")
+
+        # Join the second user
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
 
     def test_send_receipt(self) -> None:
+        # Send a message.
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertNotEqual(self._get_read_receipt(), None)
+
+    def test_send_receipt_unknown_event(self) -> None:
+        """Receipts sent for unknown events are ignored to not break message retention."""
+        # Attempt to send a receipt to an unknown room.
         channel = self.make_request(
             "POST",
             "/rooms/!abc:beep/receipt/m.read/$def",
             content={},
-            access_token=self.owner_tok,
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertIsNone(self._get_read_receipt())
+
+        # Attempt to send a receipt to an unknown event.
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/m.read/$def",
+            content={},
+            access_token=self.tok2,
         )
         self.assertEqual(channel.code, 200, channel.result)
+        self.assertIsNone(self._get_read_receipt())
+
+    def test_send_receipt_unviewable_event(self) -> None:
+        """Receipts sent for unviewable events are errors."""
+        # Create a room where new users can't see events from before their join
+        # & send events into it.
+        room_id = self.helper.create_room_as(
+            self.user_id,
+            tok=self.tok,
+            extra_content={
+                "preset": "private_chat",
+                "initial_state": [
+                    {
+                        "content": {"history_visibility": HistoryVisibility.JOINED},
+                        "state_key": "",
+                        "type": EventTypes.RoomHistoryVisibility,
+                    }
+                ],
+            },
+        )
+        res = self.helper.send(room_id, body="hello", tok=self.tok)
+
+        # Attempt to send a receipt from the wrong user.
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            content={},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+
+        # Join the user to the room, but they still can't see the event.
+        self.helper.invite(room_id, self.user_id, self.user2, tok=self.tok)
+        self.helper.join(room=room_id, user=self.user2, tok=self.tok2)
+
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            content={},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
 
     def test_send_receipt_invalid_room_id(self) -> None:
         channel = self.make_request(
             "POST",
             "/rooms/not-a-room-id/receipt/m.read/$def",
             content={},
-            access_token=self.owner_tok,
+            access_token=self.tok,
         )
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -59,7 +147,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/rooms/!abc:beep/receipt/m.read/not-an-event-id",
             content={},
-            access_token=self.owner_tok,
+            access_token=self.tok,
         )
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -71,6 +159,123 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/rooms/!abc:beep/receipt/invalid-receipt-type/$def",
             content={},
-            access_token=self.owner_tok,
+            access_token=self.tok,
         )
         self.assertEqual(channel.code, 400, channel.result)
+
+    def test_private_read_receipts(self) -> None:
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a private read receipt to tell the server the first user's message was read
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Test that the first user can't see the other user's private read receipt
+        self.assertIsNone(self._get_read_receipt())
+
+    def test_public_receipt_can_override_private(self) -> None:
+        """
+        Sending a public read receipt to the same event which has a private read
+        receipt should cause that receipt to become public.
+        """
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a private read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertIsNone(self._get_read_receipt())
+
+        # Send a public read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Test that we did override the private read receipt
+        self.assertNotEqual(self._get_read_receipt(), None)
+
+    def test_private_receipt_cannot_override_public(self) -> None:
+        """
+        Sending a private read receipt to the same event which has a public read
+        receipt should cause no change.
+        """
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a public read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertNotEqual(self._get_read_receipt(), None)
+
+        # Send a private read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Test that we didn't override the public read receipt
+        self.assertIsNone(self._get_read_receipt())
+
+    def test_read_receipt_with_empty_body_is_rejected(self) -> None:
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a read receipt for this message with an empty body
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
+
+    def _get_read_receipt(self) -> Optional[JsonDict]:
+        """Syncs and returns the read receipt."""
+
+        # Checks if event is a read receipt
+        def is_read_receipt(event: JsonDict) -> bool:
+            return event["type"] == EduTypes.RECEIPT
+
+        # Sync
+        channel = self.make_request(
+            "GET",
+            self.url % self.next_batch,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
+
+        if channel.json_body.get("rooms", None) is None:
+            return None
+
+        # Return the read receipt
+        ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
+            "ephemeral"
+        ]["events"]
+        receipt_event = filter(is_read_receipt, ephemeral_events)
+        return next(receipt_event, None)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 84a60c0b07..4e0a387bd3 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -13,13 +13,17 @@
 # limitations under the License.
 from typing import List, Optional
 
+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
 from synapse.server import HomeServer
+from synapse.storage._base import db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -217,9 +221,9 @@ class RedactionsTestCase(HomeserverTestCase):
             self._redact_event(self.mod_access_token, self.room_id, msg_id)
 
     @override_config({"experimental_features": {"msc3912_enabled": True}})
-    def test_redact_relations(self) -> None:
-        """Tests that we can redact the relations of an event at the same time as the
-        event itself.
+    def test_redact_relations_with_types(self) -> None:
+        """Tests that we can redact the relations of an event of specific types
+        at the same time as the event itself.
         """
         # Send a root event.
         res = self.helper.send_event(
@@ -318,6 +322,104 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertNotIn("redacted_because", event_dict, event_dict)
 
     @override_config({"experimental_features": {"msc3912_enabled": True}})
+    def test_redact_all_relations(self) -> None:
+        """Tests that we can redact all the relations of an event at the same time as the
+        event itself.
+        """
+        # Send a root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "hello"},
+            tok=self.mod_access_token,
+        )
+        root_event_id = res["event_id"]
+
+        # Send an edit to this root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": " * hello world",
+                "m.new_content": {
+                    "body": "hello world",
+                    "msgtype": "m.text",
+                },
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.REPLACE,
+                },
+                "msgtype": "m.text",
+            },
+            tok=self.mod_access_token,
+        )
+        edit_event_id = res["event_id"]
+
+        # Also send a threaded message whose root is the same as the edit's.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "message 1",
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.THREAD,
+                },
+            },
+            tok=self.mod_access_token,
+        )
+        threaded_event_id = res["event_id"]
+
+        # Also send a reaction, again with the same root.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Reaction,
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": root_event_id,
+                    "key": "👍",
+                }
+            },
+            tok=self.mod_access_token,
+        )
+        reaction_event_id = res["event_id"]
+
+        # Redact the root event, specifying that we also want to delete all events that
+        # relate to it.
+        self._redact_event(
+            self.mod_access_token,
+            self.room_id,
+            root_event_id,
+            with_relations=["*"],
+        )
+
+        # Check that the root event got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, root_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the edit got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, edit_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the threaded message got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, threaded_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the reaction got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, reaction_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+    @override_config({"experimental_features": {"msc3912_enabled": True}})
     def test_redact_relations_no_perms(self) -> None:
         """Tests that, when redacting a message along with its relations, if not all
         the related messages can be redacted because of insufficient permissions, the
@@ -469,35 +571,81 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertIn("body", event_dict["content"], event_dict)
         self.assertEqual("I'm in a thread!", event_dict["content"]["body"])
 
-    def test_content_redaction(self) -> None:
-        """MSC2174 moved the redacts property to the content."""
+    @parameterized.expand(
+        [
+            # Tuples of:
+            #   Room version
+            #   Boolean: True if the redaction event content should include the event ID.
+            #   Boolean: true if the resulting redaction event is expected to include the
+            #            event ID in the content.
+            (RoomVersions.V10, False, False),
+            (RoomVersions.V11, True, True),
+            (RoomVersions.V11, False, True),
+        ]
+    )
+    def test_redaction_content(
+        self, room_version: RoomVersion, include_content: bool, expect_content: bool
+    ) -> None:
+        """
+        Room version 11 moved the redacts property to the content.
+
+        Ensure that the event gets created properly and that the Client-Server
+        API servers the proper backwards-compatible version.
+        """
         # Create a room with the newer room version.
         room_id = self.helper.create_room_as(
             self.mod_user_id,
             tok=self.mod_access_token,
-            room_version=RoomVersions.MSC2176.identifier,
+            room_version=room_version.identifier,
         )
 
         # Create an event.
         b = self.helper.send(room_id=room_id, tok=self.mod_access_token)
         event_id = b["event_id"]
 
-        # Attempt to redact it with a bogus event ID.
-        self._redact_event(
+        # Ensure the event ID in the URL and the content must match.
+        if include_content:
+            self._redact_event(
+                self.mod_access_token,
+                room_id,
+                event_id,
+                expect_code=400,
+                content={"redacts": "foo"},
+            )
+
+        # Redact it for real.
+        result = self._redact_event(
             self.mod_access_token,
             room_id,
             event_id,
-            expect_code=400,
-            content={"redacts": "foo"},
+            content={"redacts": event_id} if include_content else {},
         )
-
-        # Redact it for real.
-        self._redact_event(self.mod_access_token, room_id, event_id)
+        redaction_event_id = result["event_id"]
 
         # Sync the room, to get the id of the create event
         timeline = self._sync_room_timeline(self.mod_access_token, room_id)
         redact_event = timeline[-1]
         self.assertEqual(redact_event["type"], EventTypes.Redaction)
-        # The redacts key should be in the content.
-        self.assertNotIn("redacts", redact_event)
-        self.assertEquals(redact_event["content"]["redacts"], event_id)
+        # The redacts key should be in the content and the redacts keys.
+        self.assertEqual(redact_event["content"]["redacts"], event_id)
+        self.assertEqual(redact_event["redacts"], event_id)
+
+        # But it isn't actually part of the event.
+        def get_event(txn: LoggingTransaction) -> JsonDict:
+            return db_to_json(
+                main_datastore._fetch_event_rows(txn, [redaction_event_id])[
+                    redaction_event_id
+                ].json
+            )
+
+        main_datastore = self.hs.get_datastores().main
+        event_json = self.get_success(
+            main_datastore.db_pool.runInteraction("get_event", get_event)
+        )
+        self.assertEqual(event_json["type"], EventTypes.Redaction)
+        if expect_content:
+            self.assertNotIn("redacts", event_json)
+            self.assertEqual(event_json["content"]["redacts"], event_id)
+        else:
+            self.assertEqual(event_json["redacts"], event_id)
+            self.assertNotIn("redacts", event_json["content"])
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index b228dba861..ba4e017a0e 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -75,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 200, msg=channel.result)
         det_data = {"user_id": user_id, "home_server": self.hs.hostname}
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     def test_POST_appservice_registration_no_type(self) -> None:
         as_token = "i_am_an_app_service"
@@ -136,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "device_id": device_id,
         }
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     @override_config({"enable_registration": False})
     def test_POST_disabled_registration(self) -> None:
@@ -157,7 +157,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     def test_POST_disabled_guest_registration(self) -> None:
         self.hs.config.registration.allow_guest_access = False
@@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
     def test_POST_ratelimiting_guest(self) -> None:
-        for i in range(0, 6):
+        for i in range(6):
             url = self.url + b"?kind=guest"
             channel = self.make_request(b"POST", url, b"{}")
 
@@ -187,7 +187,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
     def test_POST_ratelimiting(self) -> None:
-        for i in range(0, 6):
+        for i in range(6):
             request_data = {
                 "username": "kermit" + str(i),
                 "password": "monkey",
@@ -267,7 +267,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "device_id": device_id,
         }
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
         # Check the `completed` counter has been incremented and pending is 0
         res = self.get_success(
@@ -1223,7 +1223,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
     def test_GET_ratelimiting(self) -> None:
         token = "1234"
 
-        for i in range(0, 6):
+        for i in range(6):
             channel = self.make_request(
                 b"GET",
                 f"{self.url}?token={token}",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 75439416c1..61773fb28c 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
 
 import urllib.parse
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -28,7 +28,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
 from tests.unittest import override_config
 
@@ -129,7 +128,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         return [ev["event_id"] for ev in channel.json_body["chunk"]]
 
     def _get_bundled_aggregations(self) -> JsonDict:
@@ -142,7 +141,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         return channel.json_body["unsigned"].get("m.relations", {})
 
     def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
@@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
         # Disable the validation to pretend this came over federation.
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Generate a various relations from a different room.
             self.get_success(
@@ -570,7 +570,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
         event_result = channel.json_body
-        self.assertDictContainsSubset(original_body, event_result["content"])
+        self.assertLessEqual(original_body.items(), event_result["content"].items())
 
         # also check /context, which returns the *edited* event
         channel = self.make_request(
@@ -587,14 +587,14 @@ class RelationsTestCase(BaseRelationsTestCase):
             (context_result, "/context"),
         ):
             # The reference metadata should still be intact.
-            self.assertDictContainsSubset(
+            self.assertLessEqual(
                 {
                     "m.relates_to": {
                         "event_id": self.parent_id,
                         "rel_type": "m.reference",
                     }
-                },
-                result_event_dict["content"],
+                }.items(),
+                result_event_dict["content"].items(),
                 desc,
             )
 
@@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         # not an event the Client-Server API will allow..
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Create a sub-thread off the thread, which is not allowed.
             self._send_relation(
@@ -1371,9 +1372,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         latest_event_in_thread = thread_summary["latest_event"]
         # The latest event in the thread should have the edit appear under the
         # bundled aggregations.
-        self.assertDictContainsSubset(
-            {"event_id": edit_event_id, "sender": "@alice:test"},
-            latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE],
+        self.assertLessEqual(
+            {"event_id": edit_event_id, "sender": "@alice:test"}.items(),
+            latest_event_in_thread["unsigned"]["m.relations"][
+                RelationTypes.REPLACE
+            ].items(),
         )
 
     def test_aggregation_get_event_for_annotation(self) -> None:
@@ -1602,7 +1605,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         threads = channel.json_body["chunk"]
         return [
             (
@@ -1634,11 +1637,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         ##################################################
         # Check the test data is configured as expected. #
         ##################################################
-        self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+        self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 3, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 3, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # The latest event is the last sent event.
         self.assertEqual(
@@ -1655,11 +1658,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(thread_replies.pop())
 
         # The thread should still exist, but the latest event should be updated.
-        self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+        self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 2, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 2, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # And the latest event is the last unredacted event.
         self.assertEqual(
@@ -1674,11 +1677,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self._redact(thread_replies.pop(0))
 
         # Nothing should have changed (except the thread count).
-        self.assertEquals(self._get_related_events(), thread_replies)
+        self.assertEqual(self._get_related_events(), thread_replies)
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 1, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 1, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # And the latest event is the last unredacted event.
         self.assertEqual(
@@ -1691,11 +1694,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # Redact the last remaining event. #
         ####################################
         self._redact(thread_replies.pop(0))
-        self.assertEquals(thread_replies, [])
+        self.assertEqual(thread_replies, [])
 
         # The event should no longer be considered a thread.
-        self.assertEquals(self._get_related_events(), [])
-        self.assertEquals(self._get_bundled_aggregations(), {})
+        self.assertEqual(self._get_related_events(), [])
+        self.assertEqual(self._get_bundled_aggregations(), {})
         self.assertEqual(self._get_threads(), [])
 
     def test_redact_parent_edit(self) -> None:
@@ -1749,8 +1752,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # The relations are returned.
         event_ids = self._get_related_events()
         relations = self._get_bundled_aggregations()
-        self.assertEquals(event_ids, [related_event_id])
-        self.assertEquals(
+        self.assertEqual(event_ids, [related_event_id])
+        self.assertEqual(
             relations[RelationTypes.REFERENCE],
             {"chunk": [{"event_id": related_event_id}]},
         )
@@ -1772,13 +1775,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # The unredacted relation should still exist.
         event_ids = self._get_related_events()
         relations = self._get_bundled_aggregations()
-        self.assertEquals(len(event_ids), 1)
-        self.assertDictContainsSubset(
+        self.assertEqual(len(event_ids), 1)
+        self.assertLessEqual(
             {
                 "count": 1,
                 "current_user_participated": True,
-            },
-            relations[RelationTypes.THREAD],
+            }.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         self.assertEqual(
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
@@ -1816,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         threads = self._get_threads(channel.json_body)
         self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
 
@@ -1829,7 +1832,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         # Tuple of (thread ID, latest event ID) for each thread.
         threads = self._get_threads(channel.json_body)
         self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
@@ -1850,7 +1853,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(thread_roots, [thread_2])
 
@@ -1864,7 +1867,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(thread_roots, [thread_1], channel.json_body)
 
@@ -1899,7 +1902,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(
             thread_roots, [thread_3, thread_2, thread_1], channel.json_body
@@ -1911,7 +1914,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
 
@@ -1943,6 +1946,6 @@ class ThreadsTestCase(BaseRelationsTestCase):
             f"/_matrix/client/v1/rooms/{self.room}/threads",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
         self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
deleted file mode 100644
index 9d5cb60d16..0000000000
--- a/tests/rest/client/test_room_batch.py
+++ /dev/null
@@ -1,302 +0,0 @@
-import logging
-from typing import List, Tuple
-from unittest.mock import Mock, patch
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.appservice import ApplicationService
-from synapse.rest import admin
-from synapse.rest.client import login, register, room, room_batch, sync
-from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken
-from synapse.util import Clock
-
-from tests import unittest
-
-logger = logging.getLogger(__name__)
-
-
-def _create_join_state_events_for_batch_send_request(
-    virtual_user_ids: List[str],
-    insert_time: int,
-) -> List[JsonDict]:
-    return [
-        {
-            "type": EventTypes.Member,
-            "sender": virtual_user_id,
-            "origin_server_ts": insert_time,
-            "content": {
-                "membership": "join",
-                "displayname": "display-name-for-%s" % (virtual_user_id,),
-            },
-            "state_key": virtual_user_id,
-        }
-        for virtual_user_id in virtual_user_ids
-    ]
-
-
-def _create_message_events_for_batch_send_request(
-    virtual_user_id: str, insert_time: int, count: int
-) -> List[JsonDict]:
-    return [
-        {
-            "type": EventTypes.Message,
-            "sender": virtual_user_id,
-            "origin_server_ts": insert_time,
-            "content": {
-                "msgtype": "m.text",
-                "body": "Historical %d" % (i),
-                EventContentFields.MSC2716_HISTORICAL: True,
-            },
-        }
-        for i in range(count)
-    ]
-
-
-class RoomBatchTestCase(unittest.HomeserverTestCase):
-    """Test importing batches of historical messages."""
-
-    servlets = [
-        admin.register_servlets_for_client_rest_resource,
-        room_batch.register_servlets,
-        room.register_servlets,
-        register.register_servlets,
-        login.register_servlets,
-        sync.register_servlets,
-    ]
-
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        config = self.default_config()
-
-        self.appservice = ApplicationService(
-            token="i_am_an_app_service",
-            id="1234",
-            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
-            # Note: this user does not have to match the regex above
-            sender="@as_main:test",
-        )
-
-        mock_load_appservices = Mock(return_value=[self.appservice])
-        with patch(
-            "synapse.storage.databases.main.appservice.load_appservices",
-            mock_load_appservices,
-        ):
-            hs = self.setup_test_homeserver(config=config)
-        return hs
-
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.clock = clock
-        self._storage_controllers = hs.get_storage_controllers()
-
-        self.virtual_user_id, _ = self.register_appservice_user(
-            "as_user_potato", self.appservice.token
-        )
-
-    def _create_test_room(self) -> Tuple[str, str, str, str]:
-        room_id = self.helper.create_room_as(
-            self.appservice.sender, tok=self.appservice.token
-        )
-
-        res_a = self.helper.send_event(
-            room_id=room_id,
-            type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "A",
-            },
-            tok=self.appservice.token,
-        )
-        event_id_a = res_a["event_id"]
-
-        res_b = self.helper.send_event(
-            room_id=room_id,
-            type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "B",
-            },
-            tok=self.appservice.token,
-        )
-        event_id_b = res_b["event_id"]
-
-        res_c = self.helper.send_event(
-            room_id=room_id,
-            type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "C",
-            },
-            tok=self.appservice.token,
-        )
-        event_id_c = res_c["event_id"]
-
-        return room_id, event_id_a, event_id_b, event_id_c
-
-    @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
-    def test_same_state_groups_for_whole_historical_batch(self) -> None:
-        """Make sure that when using the `/batch_send` endpoint to import a
-        bunch of historical messages, it re-uses the same `state_group` across
-        the whole batch. This is an easy optimization to make sure we're getting
-        right because the state for the whole batch is contained in
-        `state_events_at_start` and can be shared across everything.
-        """
-
-        time_before_room = int(self.clock.time_msec())
-        room_id, event_id_a, _, _ = self._create_test_room()
-
-        channel = self.make_request(
-            "POST",
-            "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
-            % (room_id, event_id_a),
-            content={
-                "events": _create_message_events_for_batch_send_request(
-                    self.virtual_user_id, time_before_room, 3
-                ),
-                "state_events_at_start": _create_join_state_events_for_batch_send_request(
-                    [self.virtual_user_id], time_before_room
-                ),
-            },
-            access_token=self.appservice.token,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        # Get the historical event IDs that we just imported
-        historical_event_ids = channel.json_body["event_ids"]
-        self.assertEqual(len(historical_event_ids), 3)
-
-        # Fetch the state_groups
-        state_group_map = self.get_success(
-            self._storage_controllers.state.get_state_groups_ids(
-                room_id, historical_event_ids
-            )
-        )
-
-        # We expect all of the historical events to be using the same state_group
-        # so there should only be a single state_group here!
-        self.assertEqual(
-            len(state_group_map.keys()),
-            1,
-            "Expected a single state_group to be returned by saw state_groups=%s"
-            % (state_group_map.keys(),),
-        )
-
-    @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
-    def test_sync_while_batch_importing(self) -> None:
-        """
-        Make sure that /sync correctly returns full room state when a user joins
-        during ongoing batch backfilling.
-        See: https://github.com/matrix-org/synapse/issues/12281
-        """
-        # Create user who will be invited & join room
-        user_id = self.register_user("beep", "test")
-        user_tok = self.login("beep", "test")
-
-        time_before_room = int(self.clock.time_msec())
-
-        # Create a room with some events
-        room_id, _, _, _ = self._create_test_room()
-        # Invite the user
-        self.helper.invite(
-            room_id, src=self.appservice.sender, tok=self.appservice.token, targ=user_id
-        )
-
-        # Create another room, send a bunch of events to advance the stream token
-        other_room_id = self.helper.create_room_as(
-            self.appservice.sender, tok=self.appservice.token
-        )
-        for _ in range(5):
-            self.helper.send_event(
-                room_id=other_room_id,
-                type=EventTypes.Message,
-                content={"msgtype": "m.text", "body": "C"},
-                tok=self.appservice.token,
-            )
-
-        # Join the room as the normal user
-        self.helper.join(room_id, user_id, tok=user_tok)
-
-        # Create an event to hang the historical batch from - In order to see
-        # the failure case originally reported in #12281, the historical batch
-        # must be hung from the most recent event in the room so the base
-        # insertion event ends up with the highest `topogological_ordering`
-        # (`depth`) in the room but will have a negative `stream_ordering`
-        # because it's a `historical` event. Previously, when assembling the
-        # `state` for the `/sync` response, the bugged logic would sort by
-        # `topological_ordering` descending and pick up the base insertion
-        # event because it has a negative `stream_ordering` below the given
-        # pagination token. Now we properly sort by `stream_ordering`
-        # descending which puts `historical` events with a negative
-        # `stream_ordering` way at the bottom and aren't selected as expected.
-        response = self.helper.send_event(
-            room_id=room_id,
-            type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "C",
-            },
-            tok=self.appservice.token,
-        )
-        event_to_hang_id = response["event_id"]
-
-        channel = self.make_request(
-            "POST",
-            "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
-            % (room_id, event_to_hang_id),
-            content={
-                "events": _create_message_events_for_batch_send_request(
-                    self.virtual_user_id, time_before_room, 3
-                ),
-                "state_events_at_start": _create_join_state_events_for_batch_send_request(
-                    [self.virtual_user_id], time_before_room
-                ),
-            },
-            access_token=self.appservice.token,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        # Now we need to find the invite + join events stream tokens so we can sync between
-        main_store = self.hs.get_datastores().main
-        events, next_key = self.get_success(
-            main_store.get_recent_events_for_room(
-                room_id,
-                50,
-                end_token=main_store.get_room_max_token(),
-            ),
-        )
-        invite_event_position = None
-        for event in events:
-            if (
-                event.type == "m.room.member"
-                and event.content["membership"] == "invite"
-            ):
-                invite_event_position = self.get_success(
-                    main_store.get_topological_token_for_event(event.event_id)
-                )
-                break
-
-        assert invite_event_position is not None, "No invite event found"
-
-        # Remove the topological order from the token by re-creating w/stream only
-        invite_event_position = RoomStreamToken(None, invite_event_position.stream)
-
-        # Sync everything after this token
-        since_token = self.get_success(invite_event_position.to_string(main_store))
-        sync_response = self.make_request(
-            "GET",
-            f"/sync?since={since_token}",
-            access_token=user_tok,
-        )
-
-        # Assert that, for this room, the user was considered to have joined and thus
-        # receives the full state history
-        state_event_types = [
-            event["type"]
-            for event in sync_response.json_body["rooms"]["join"][room_id]["state"][
-                "events"
-            ]
-        ]
-
-        assert (
-            "m.room.create" in state_event_types
-        ), "Missing room full state in sync response"
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4d39c89f6f..aaa4f3bba0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
 import json
 from http import HTTPStatus
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call, patch
+from unittest.mock import AsyncMock, Mock, call, patch
 from urllib import parse as urlparse
 
 from parameterized import param, parameterized
@@ -41,7 +41,6 @@ from synapse.api.errors import Codes, HttpResponseException
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client import account, directory, login, profile, register, room, sync
 from synapse.server import HomeServer
@@ -52,7 +51,6 @@ from synapse.util.stringutils import random_string
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
 from tests.storage.test_stream import PaginationTestCase
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import create_event
 from tests.unittest import override_config
 
@@ -67,19 +65,17 @@ class RoomBase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.hs = self.setup_test_homeserver(
             "red",
-            federation_http_client=None,
-            federation_client=Mock(),
         )
 
-        self.hs.get_federation_handler = Mock()  # type: ignore[assignment]
-        self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
-            return_value=make_awaitable(None)
+        self.hs.get_federation_handler = Mock()  # type: ignore[method-assign]
+        self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
+            return_value=None
         )
 
         async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
             return None
 
-        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip  # type: ignore[assignment]
+        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip  # type: ignore[method-assign]
 
         return self.hs
 
@@ -713,7 +709,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(30, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -726,7 +722,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(32, channel.resource_usage.db_txn_count)
+        self.assertEqual(34, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
@@ -1364,7 +1360,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
 
         # Ensure the event was persisted with the correct timestamp.
         res = self.get_success(self.main_store.get_event(event_id))
-        self.assertEquals(ts, res.origin_server_ts)
+        self.assertEqual(ts, res.origin_server_ts)
 
     def test_send_state_event_ts(self) -> None:
         """Test sending a state event with a custom timestamp."""
@@ -1386,7 +1382,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
 
         # Ensure the event was persisted with the correct timestamp.
         res = self.get_success(self.main_store.get_event(event_id))
-        self.assertEquals(ts, res.origin_server_ts)
+        self.assertEqual(ts, res.origin_server_ts)
 
     def test_send_membership_event_ts(self) -> None:
         """Test sending a membership event with a custom timestamp."""
@@ -1408,7 +1404,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
 
         # Ensure the event was persisted with the correct timestamp.
         res = self.get_success(self.main_store.get_event(event_id))
-        self.assertEquals(ts, res.origin_server_ts)
+        self.assertEqual(ts, res.origin_server_ts)
 
 
 class RoomJoinRatelimitTestCase(RoomBase):
@@ -1451,6 +1447,30 @@ class RoomJoinRatelimitTestCase(RoomBase):
     @unittest.override_config(
         {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
     )
+    def test_join_attempts_local_ratelimit(self) -> None:
+        """Tests that unsuccessful joins that end up being denied are rate-limited."""
+        # Create 4 rooms
+        room_ids = [
+            self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+        ]
+        # Pre-emptively ban the user who will attempt to join.
+        joiner_user_id = self.register_user("joiner", "secret")
+        for room_id in room_ids:
+            self.helper.ban(room_id, self.user_id, joiner_user_id)
+
+        # Now make a new user try to join some of them.
+        # The user can make 3 requests, each of which should be denied.
+        for room_id in room_ids[0:3]:
+            self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN)
+
+        # The fourth attempt should be rate limited.
+        self.helper.join(
+            room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+        )
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+    )
     def test_join_local_ratelimit_profile_change(self) -> None:
         """Tests that sending a profile update into all of the user's joined rooms isn't
         rate-limited by the rate-limiter on joins."""
@@ -1941,6 +1961,43 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
             channel.json_body["error"],
         )
 
+    @unittest.override_config(
+        {
+            "default_power_level_content_override": {
+                "private_chat": {
+                    "events": {
+                        "m.room.avatar": 50,
+                        "m.room.canonical_alias": 50,
+                        "m.room.encryption": 999,
+                        "m.room.history_visibility": 100,
+                        "m.room.name": 50,
+                        "m.room.power_levels": 100,
+                        "m.room.server_acl": 100,
+                        "m.room.tombstone": 100,
+                    },
+                    "events_default": 0,
+                },
+            }
+        },
+    )
+    def test_config_override_blocks_encrypted_room(self) -> None:
+        # Given the server has config for private_chats,
+
+        # When I attempt to create an encrypted private_chat room
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            '{"creation_content": {"m.federate": false},"name": "Secret Private Room","preset": "private_chat","initial_state": [{"type": "m.room.encryption","state_key": "","content": {"algorithm": "m.megolm.v1.aes-sha2"}}]}',
+        )
+
+        # Then I am not allowed because the required power level is unattainable
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
+        self.assertEqual(
+            "You cannot create an encrypted room. "
+            + "user_level (100) < send_level (999)",
+            channel.json_body["error"],
+        )
+
 
 class RoomInitialSyncTestCase(RoomBase):
     """Tests /rooms/$room_id/initialSync."""
@@ -2052,11 +2109,8 @@ class RoomMessageListTestCase(RoomBase):
         self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
 
         # Purge every event before the second event.
-        purge_id = random_string(16)
-        pagination_handler._purges_by_id[purge_id] = PurgeStatus()
         self.get_success(
-            pagination_handler._purge_history(
-                purge_id=purge_id,
+            pagination_handler.purge_history(
                 room_id=self.room_id,
                 token=second_token_str,
                 delete_local_events=True,
@@ -2340,7 +2394,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        return self.setup_test_homeserver(federation_client=Mock())
+        return self.setup_test_homeserver(federation_client=AsyncMock())
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.register_user("user", "pass")
@@ -2350,7 +2404,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
     def test_simple(self) -> None:
         "Simple test for searching rooms over federation"
-        self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
+        self.federation_client.get_public_rooms.return_value = {}  # type: ignore[attr-defined]
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
@@ -2378,7 +2432,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         # with a 404, when using search filters.
         self.federation_client.get_public_rooms.side_effect = (  # type: ignore[attr-defined]
             HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
-            make_awaitable({}),
+            {},
         )
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@@ -3378,17 +3432,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[method-assign]
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
+        mock = AsyncMock(return_value=True, spec=lambda *x: None)
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
             mock
         )
@@ -3416,7 +3470,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite.
-        mock.return_value = make_awaitable(False)
+        mock.return_value = False
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3442,18 +3496,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[method-assign]
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(
-            return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+        mock = AsyncMock(
+            return_value=synapse.module_api.NOT_SPAM,
             spec=lambda *x: None,
         )
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@@ -3484,7 +3538,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite. We pick an arbitrary error code to be able to check
         # that the same code has been returned
-        mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+        mock.return_value = Codes.CONSENT_NOT_GIVEN
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3503,7 +3557,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         make_invite_mock.assert_called_once()
 
         # Run variant with `Tuple[Codes, dict]`.
-        mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+        mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 8d2cdf8751..9aecf88e41 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_invite_3pid(self) -> None:
         """Ensure that a 3PID invite does not attempt to contact the identity server."""
         identity_handler = self.hs.get_identity_handler()
-        identity_handler.lookup_3pid = Mock(  # type: ignore[assignment]
+        identity_handler.lookup_3pid = Mock(  # type: ignore[method-assign]
             side_effect=AssertionError("This should not get called")
         )
 
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 9c876c7a32..d60665254e 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,8 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from http import HTTPStatus
-from typing import List, Optional
+from typing import List
 
 from parameterized import parameterized
 
@@ -22,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.constants import (
-    EduTypes,
     EventContentFields,
     EventTypes,
     ReceiptTypes,
@@ -376,156 +374,6 @@ class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
         )
 
 
-class ReadReceiptsTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets,
-        login.register_servlets,
-        receipts.register_servlets,
-        room.register_servlets,
-        sync.register_servlets,
-    ]
-
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        config = self.default_config()
-
-        return self.setup_test_homeserver(config=config)
-
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.url = "/sync?since=%s"
-        self.next_batch = "s0"
-
-        # Register the first user
-        self.user_id = self.register_user("kermit", "monkey")
-        self.tok = self.login("kermit", "monkey")
-
-        # Create the room
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
-        # Register the second user
-        self.user2 = self.register_user("kermit2", "monkey")
-        self.tok2 = self.login("kermit2", "monkey")
-
-        # Join the second user
-        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
-
-    def test_private_read_receipts(self) -> None:
-        # Send a message as the first user
-        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
-        # Send a private read receipt to tell the server the first user's message was read
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
-            {},
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, 200)
-
-        # Test that the first user can't see the other user's private read receipt
-        self.assertIsNone(self._get_read_receipt())
-
-    def test_public_receipt_can_override_private(self) -> None:
-        """
-        Sending a public read receipt to the same event which has a private read
-        receipt should cause that receipt to become public.
-        """
-        # Send a message as the first user
-        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
-        # Send a private read receipt
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
-            {},
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, 200)
-        self.assertIsNone(self._get_read_receipt())
-
-        # Send a public read receipt
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
-            {},
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, 200)
-
-        # Test that we did override the private read receipt
-        self.assertNotEqual(self._get_read_receipt(), None)
-
-    def test_private_receipt_cannot_override_public(self) -> None:
-        """
-        Sending a private read receipt to the same event which has a public read
-        receipt should cause no change.
-        """
-        # Send a message as the first user
-        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
-        # Send a public read receipt
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
-            {},
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, 200)
-        self.assertNotEqual(self._get_read_receipt(), None)
-
-        # Send a private read receipt
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
-            {},
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, 200)
-
-        # Test that we didn't override the public read receipt
-        self.assertIsNone(self._get_read_receipt())
-
-    def test_read_receipt_with_empty_body_is_rejected(self) -> None:
-        # Send a message as the first user
-        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
-
-        # Send a read receipt for this message with an empty body
-        channel = self.make_request(
-            "POST",
-            f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
-            access_token=self.tok2,
-        )
-        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
-        self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body)
-
-    def _get_read_receipt(self) -> Optional[JsonDict]:
-        """Syncs and returns the read receipt."""
-
-        # Checks if event is a read receipt
-        def is_read_receipt(event: JsonDict) -> bool:
-            return event["type"] == EduTypes.RECEIPT
-
-        # Sync
-        channel = self.make_request(
-            "GET",
-            self.url % self.next_batch,
-            access_token=self.tok,
-        )
-        self.assertEqual(channel.code, 200)
-
-        # Store the next batch for the next request.
-        self.next_batch = channel.json_body["next_batch"]
-
-        if channel.json_body.get("rooms", None) is None:
-            return None
-
-        # Return the read receipt
-        ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
-            "ephemeral"
-        ]["events"]
-        receipt_event = filter(is_read_receipt, ephemeral_events)
-        return next(receipt_event, None)
-
-
 class UnreadMessagesTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e5ba5a9706..57eb713b15 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import threading
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -33,7 +33,6 @@ from synapse.util import Clock
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 if TYPE_CHECKING:
     from synapse.module_api import ModuleApi
@@ -118,7 +117,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
             pass
 
-        hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[method-assign]
 
         return hs
 
@@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_on_new_event(self) -> None:
         """Test that the on_new_event callback is called on new events"""
-        on_new_event = Mock(make_awaitable(None))
+        on_new_event = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
             on_new_event
         )
@@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(None))
+        deactivation_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(
             deactivation_mock,
@@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # Also register a mocked callback for profile updates, to check that the
         # deactivation code calls it in a way that let modules know the user is being
         # deactivated.
-        profile_mock = Mock(return_value=make_awaitable(None))
+        profile_mock = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             profile_mock,
         )
@@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         well as a reactivation.
         """
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
 
@@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation triggered by a server admin.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing an admin's shutdown room request.
         """
         # Register a mocked callback.
-        shutdown_mock = Mock(return_value=make_awaitable(False))
+        shutdown_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_shutdown_room_callbacks.append(
             shutdown_mock,
@@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         associating a 3PID to an account.
         """
         # Register a mocked callback.
-        threepid_bind_mock = Mock(return_value=make_awaitable(None))
+        threepid_bind_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
 
@@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         just before associating and removing a 3PID to/from an account.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_add_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
-        )
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         when a user is deactivated and their third-party ID associations are deleted.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index d8dc56261a..951a3cbc43 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -14,7 +14,7 @@
 
 from http import HTTPStatus
 from typing import Any, Generator, Tuple, cast
-from unittest.mock import Mock, call
+from unittest.mock import AsyncMock, Mock, call
 
 from twisted.internet import defer, reactor as _reactor
 
@@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 reactor = cast(ISynapseReactor, _reactor)
@@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_executes_given_function(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         res = yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
         )
@@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_deduplicates_based_on_key(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute_request(
                 self.mock_request,
@@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "an arg"
         )
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 9532e5ddc1..465b696c0b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -37,7 +37,6 @@ import attr
 from typing_extensions import Literal
 
 from twisted.test.proto_helpers import MemoryReactorClock
-from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
@@ -45,7 +44,7 @@ from synapse.api.errors import Codes
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 
-from tests.server import FakeChannel, FakeSite, make_request
+from tests.server import FakeChannel, make_request
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 
@@ -558,7 +557,6 @@ class RestHelper:
 
     def upload_media(
         self,
-        resource: Resource,
         image_data: bytes,
         tok: str,
         filename: str = "test.png",
@@ -576,7 +574,7 @@ class RestHelper:
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
             self.reactor,
-            FakeSite(resource, self.reactor),
+            self.site,
             "POST",
             path,
             content=image_data,
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 170fb0534a..24459c6af4 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -24,10 +24,10 @@ from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IAddress, IResolutionReceiver
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.resource import Resource
 
 from synapse.config.oembed import OEmbedEndpointConfig
 from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
-from synapse.rest.media.media_repository_resource import MediaRepositoryResource
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -40,7 +40,7 @@ from tests.test_utils import SMALL_PNG
 try:
     import lxml
 except ImportError:
-    lxml = None
+    lxml = None  # type: ignore[assignment]
 
 
 class URLPreviewTests(unittest.HomeserverTestCase):
@@ -117,8 +117,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.media_repo = hs.get_media_repository()
-        media_repo_resource = hs.get_media_repository_resource()
-        self.preview_url = media_repo_resource.children[b"preview_url"]
+        assert self.media_repo.url_previewer is not None
+        self.url_previewer = self.media_repo.url_previewer
 
         self.lookups: Dict[str, Any] = {}
 
@@ -143,8 +143,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.reactor.nameResolver = Resolver()  # type: ignore[assignment]
 
-    def create_test_resource(self) -> MediaRepositoryResource:
-        return self.hs.get_media_repository_resource()
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
+
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        return {"/_matrix/media": self.hs.get_media_repository_resource()}
 
     def _assert_small_png(self, json_body: JsonDict) -> None:
         """Assert properties from the SMALL_PNG test image."""
@@ -159,7 +166,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -183,7 +190,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         # Check the cache returns the correct response
         channel = self.make_request(
-            "GET", "preview_url?url=http://matrix.org", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
+            shorthand=False,
         )
 
         # Check the cache response has the same content
@@ -193,13 +202,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
 
         # Clear the in-memory cache
-        self.assertIn("http://matrix.org", self.preview_url._url_previewer._cache)
-        self.preview_url._url_previewer._cache.pop("http://matrix.org")
-        self.assertNotIn("http://matrix.org", self.preview_url._url_previewer._cache)
+        self.assertIn("http://matrix.org", self.url_previewer._cache)
+        self.url_previewer._cache.pop("http://matrix.org")
+        self.assertNotIn("http://matrix.org", self.url_previewer._cache)
 
         # Check the database cache returns the correct response
         channel = self.make_request(
-            "GET", "preview_url?url=http://matrix.org", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
+            shorthand=False,
         )
 
         # Check the cache response has the same content
@@ -221,7 +232,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -251,7 +262,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -287,7 +298,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -328,7 +339,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -363,7 +374,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -396,7 +407,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://example.com",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -425,7 +436,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
 
         channel = self.make_request(
-            "GET", "preview_url?url=http://example.com", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
 
         # No requests made.
@@ -446,7 +459,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
 
         channel = self.make_request(
-            "GET", "preview_url?url=http://example.com", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
 
         self.assertEqual(channel.code, 502)
@@ -463,7 +478,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         Blocked IP addresses, accessed directly, are not spidered.
         """
         channel = self.make_request(
-            "GET", "preview_url?url=http://192.168.1.1", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://192.168.1.1",
+            shorthand=False,
         )
 
         # No requests made.
@@ -479,7 +496,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         Blocked IP ranges, accessed directly, are not spidered.
         """
         channel = self.make_request(
-            "GET", "preview_url?url=http://1.1.1.2", shorthand=False
+            "GET", "/_matrix/media/v3/preview_url?url=http://1.1.1.2", shorthand=False
         )
 
         self.assertEqual(channel.code, 403)
@@ -497,7 +514,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://example.com",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -533,7 +550,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         ]
 
         channel = self.make_request(
-            "GET", "preview_url?url=http://example.com", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
         self.assertEqual(channel.code, 502)
         self.assertEqual(
@@ -553,7 +572,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         ]
 
         channel = self.make_request(
-            "GET", "preview_url?url=http://example.com", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
 
         # No requests made.
@@ -574,7 +595,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
 
         channel = self.make_request(
-            "GET", "preview_url?url=http://example.com", shorthand=False
+            "GET",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
 
         self.assertEqual(channel.code, 502)
@@ -591,10 +614,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         OPTIONS returns the OPTIONS.
         """
         channel = self.make_request(
-            "OPTIONS", "preview_url?url=http://example.com", shorthand=False
+            "OPTIONS",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
+            shorthand=False,
         )
-        self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body, {})
+        self.assertEqual(channel.code, 204)
 
     def test_accept_language_config_option(self) -> None:
         """
@@ -605,7 +629,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # Build and make a request to the server
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://example.com",
+            "/_matrix/media/v3/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -658,7 +682,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -708,7 +732,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -750,7 +774,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -790,7 +814,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -831,7 +855,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"preview_url?{query_params}",
+            f"/_matrix/media/v3/preview_url?{query_params}",
             shorthand=False,
         )
         self.pump()
@@ -852,7 +876,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://matrix.org",
+            "/_matrix/media/v3/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -889,7 +913,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -949,7 +973,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -998,7 +1022,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://www.hulu.com/watch/12345",
+            "/_matrix/media/v3/preview_url?url=http://www.hulu.com/watch/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1043,7 +1067,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1072,7 +1096,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+            "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1164,7 +1188,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+            "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1205,7 +1229,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=http://cdn.twitter.com/matrixdotorg",
+            "/_matrix/media/v3/preview_url?url=http://cdn.twitter.com/matrixdotorg",
             shorthand=False,
             await_result=False,
         )
@@ -1247,7 +1271,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # Check fetching
         channel = self.make_request(
             "GET",
-            f"download/{host}/{media_id}",
+            f"/_matrix/media/v3/download/{host}/{media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -1260,7 +1284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"download/{host}/{media_id}",
+            f"/_matrix/media/v3/download/{host}/{media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -1295,7 +1319,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # Check fetching
         channel = self.make_request(
             "GET",
-            f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+            f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
             shorthand=False,
             await_result=False,
         )
@@ -1313,7 +1337,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+            f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
             shorthand=False,
             await_result=False,
         )
@@ -1343,7 +1367,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertTrue(os.path.isdir(thumbnail_dir))
 
         self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1)
-        self.get_success(self.preview_url._url_previewer._expire_url_cache_data())
+        self.get_success(self.url_previewer._expire_url_cache_data())
 
         for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
             self.assertFalse(
@@ -1363,7 +1387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=" + bad_url,
+            "/_matrix/media/v3/preview_url?url=" + bad_url,
             shorthand=False,
             await_result=False,
         )
@@ -1372,7 +1396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=" + good_url,
+            "/_matrix/media/v3/preview_url?url=" + good_url,
             shorthand=False,
             await_result=False,
         )
@@ -1404,7 +1428,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "preview_url?url=" + bad_url,
+            "/_matrix/media/v3/preview_url?url=" + bad_url,
             shorthand=False,
             await_result=False,
         )
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 2091b08d89..377243a170 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -17,6 +17,13 @@ from synapse.rest.well_known import well_known_resource
 
 from tests import unittest
 
+try:
+    import authlib  # noqa: F401
+
+    HAS_AUTHLIB = True
+except ImportError:
+    HAS_AUTHLIB = False
+
 
 class WellKnownTests(unittest.HomeserverTestCase):
     def create_test_resource(self) -> Resource:
@@ -96,3 +103,37 @@ class WellKnownTests(unittest.HomeserverTestCase):
             "GET", "/.well-known/matrix/server", shorthand=False
         )
         self.assertEqual(channel.code, 404)
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://homeserver",  # this is only required so that client well known is served
+            "experimental_features": {
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "account_management_url": "https://my-account.issuer",
+                    "client_id": "id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "secret",
+                },
+            },
+            "disable_registration": True,
+        }
+    )
+    def test_client_well_known_msc3861_oauth_delegation(self) -> None:
+        channel = self.make_request(
+            "GET", "/.well-known/matrix/client", shorthand=False
+        )
+
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body,
+            {
+                "m.homeserver": {"base_url": "https://homeserver/"},
+                "org.matrix.msc2965.authentication": {
+                    "issuer": "https://issuer",
+                    "account": "https://my-account.issuer",
+                },
+            },
+        )
diff --git a/tests/server.py b/tests/server.py
index 7296f0a552..08633fe640 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import hashlib
+import ipaddress
 import json
 import logging
 import os
@@ -26,6 +27,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Deque,
     Dict,
     Iterable,
     List,
@@ -41,10 +43,10 @@ from typing import (
 from unittest.mock import Mock
 
 import attr
-from typing_extensions import Deque, ParamSpec
+from typing_extensions import ParamSpec
 from zope.interface import implementer
 
-from twisted.internet import address, threads, udp
+from twisted.internet import address, tcp, threads, udp
 from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
@@ -53,6 +55,7 @@ from twisted.internet.interfaces import (
     IConnector,
     IConsumer,
     IHostnameResolver,
+    IListeningPort,
     IProducer,
     IProtocol,
     IPullProducer,
@@ -62,7 +65,7 @@ from twisted.internet.interfaces import (
     IResolverSimple,
     ITransport,
 )
-from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
 from twisted.python import threadpool
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@@ -523,6 +526,35 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         self._tcp_callbacks[(host, port)] = callback
 
+    def connectUNIX(
+        self,
+        address: str,
+        factory: ClientFactory,
+        timeout: float = 30,
+        checkPID: int = 0,
+    ) -> IConnector:
+        """
+        Unix sockets aren't supported for unit tests yet. Make it obvious to any
+        developer trying it out that they will need to do some work before being able
+        to use it in tests.
+        """
+        raise Exception("Unix sockets are not implemented for tests yet, sorry.")
+
+    def listenUNIX(
+        self,
+        address: str,
+        factory: Factory,
+        backlog: int = 50,
+        mode: int = 0o666,
+        wantPID: int = 0,
+    ) -> IListeningPort:
+        """
+        Unix sockets aren't supported for unit tests yet. Make it obvious to any
+        developer trying it out that they will need to do some work before being able
+        to use it in tests.
+        """
+        raise Exception("Unix sockets are not implemented for tests, sorry")
+
     def connectTCP(
         self,
         host: str,
@@ -536,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         conn = super().connectTCP(
             host, port, factory, timeout=timeout, bindAddress=None
         )
+        if self.lookups and host in self.lookups:
+            validate_connector(conn, self.lookups[host])
 
         callback = self._tcp_callbacks.get((host, port))
         if callback:
@@ -568,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
             super().advance(0)
 
 
+def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
+    """Try to validate the obtained connector as it would happen when
+    synapse is running and the conection will be established.
+
+    This method will raise a useful exception when necessary, else it will
+    just do nothing.
+
+    This is in order to help catch quirks related to reactor.connectTCP,
+    since when called directly, the connector's destination will be of type
+    IPv4Address, with the hostname as the literal host that was given (which
+    could be an IPv6-only host or an IPv6 literal).
+
+    But when called from reactor.connectTCP *through* e.g. an Endpoint, the
+    connector's destination will contain the specific IP address with the
+    correct network stack class.
+
+    Note that testing code paths that use connectTCP directly should not be
+    affected by this check, unless they specifically add a test with a
+    matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
+    type ThreadedMemoryReactorClock.
+    For an example of implementing such tests, see test/handlers/send_email.py.
+    """
+    destination = connector.getDestination()
+
+    # We use address.IPv{4,6}Address to check what the reactor thinks it is
+    # is sending but check for validity with ipaddress.IPv{4,6}Address
+    # because they fail with IPs on the wrong network stack.
+    cls_mapping = {
+        address.IPv4Address: ipaddress.IPv4Address,
+        address.IPv6Address: ipaddress.IPv6Address,
+    }
+
+    cls = cls_mapping.get(destination.__class__)
+
+    if cls is not None:
+        try:
+            cls(expected_ip)
+        except Exception as exc:
+            raise ValueError(
+                "Invalid IP type and resolution for %s. Expected %s to be %s"
+                % (destination, expected_ip, cls.__name__)
+            ) from exc
+    else:
+        raise ValueError(
+            "Unknown address type %s for %s"
+            % (destination.__class__.__name__, destination)
+        )
+
+
 class ThreadPool:
     """
     Threadless thread pool.
@@ -639,10 +722,10 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
                 **kwargs,
             )
 
-        pool.runWithConnection = runWithConnection  # type: ignore[assignment]
+        pool.runWithConnection = runWithConnection  # type: ignore[method-assign]
         pool.runInteraction = runInteraction  # type: ignore[assignment]
         # Replace the thread pool with a threadless 'thread' pool
-        pool.threadpool = ThreadPool(clock._reactor)  # type: ignore[assignment]
+        pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
     # We've just changed the Databases to run DB transactions on the same
@@ -969,8 +1052,6 @@ def setup_test_homeserver(
     hs.tls_server_context_factory = Mock()
 
     hs.setup()
-    if homeserver_to_use == TestHomeServer:
-        hs.setup_background_tasks()
 
     if isinstance(db_engine, PostgresEngine):
         database_pool = hs.get_datastores().databases[0]
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index d2bfa53eda..17f428bfc5 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -29,7 +29,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import default_config
 
@@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         assert isinstance(rlsn, ResourceLimitsServerNotices)
         self._rlsn = rlsn
 
-        self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(1000)
-        )
-        self._rlsn._server_notices_manager.send_notice = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(Mock())
+        self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
+        self._rlsn._server_notices_manager.send_notice = AsyncMock(  # type: ignore[method-assign]
+            return_value=Mock()
         )
         self._send_notice = self._rlsn._server_notices_manager.send_notice
 
         self.user_id = "@user_id:test"
 
-        self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
-            return_value=make_awaitable("!something:localhost")
+        self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = (
+            AsyncMock(return_value="!something:localhost")
         )
-        self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
-            return_value=make_awaitable("!something:localhost")
+        self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock(
+            return_value="!something:localhost"
         )
-        self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
-        self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))  # type: ignore[assignment]
+        self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None)  # type: ignore[method-assign]
+        self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})  # type: ignore[method-assign]
 
     @override_config({"hs_disabled": True})
     def test_maybe_send_server_notice_disabled_hs(self) -> None:
@@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
@@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(403, "foo"),
         )
 
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(403, "foo"),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
-        )
-        self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
+        self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None)
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
         self._send_notice.assert_not_called()
@@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
@@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
             ),
@@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
         )
 
-        self._rlsn._is_room_currently_blocked = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable((True, []))
+        self._rlsn._is_room_currently_blocked = AsyncMock(  # type: ignore[method-assign]
+            return_value=(True, [])
         )
 
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.user_id = "@user_id:test"
 
     def test_server_notice_only_sent_once(self) -> None:
-        self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
+        self.store.get_monthly_active_count = AsyncMock(return_value=1000)
 
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(1000)
-        )
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
 
         # Call the function multiple times to ensure we only send the notice once
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         hasn't been reached (since it's the only user and the limit is 5), so users
         shouldn't receive a server notice.
         """
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
 
         user_id = self.register_user("user", "password")
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 2e3f2318d9..6a2f7584f6 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -719,7 +719,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         persisted_events = {a.event_id: a, b.event_id: b}
         unpersited_events = {c.event_id: c}
 
-        state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+        state_sets = [
+            {("a", ""): a.event_id, ("b", ""): b.event_id},
+            {("c", ""): c.event_id},
+        ]
 
         store = TestStateResolutionStore(persisted_events)
 
@@ -774,8 +777,8 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         unpersited_events = {c.event_id: c, d.event_id: d}
 
         state_sets = [
-            {"a": a.event_id, "b": b.event_id},
-            {"c": c.event_id, "d": d.event_id},
+            {("a", ""): a.event_id, ("b", ""): b.event_id},
+            {("c", ""): c.event_id, ("d", ""): d.event_id},
         ]
 
         store = TestStateResolutionStore(persisted_events)
@@ -841,8 +844,8 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
 
         state_sets = [
-            {"a": a.event_id, "b": b.event_id, "e": e.event_id},
-            {"c": c.event_id, "d": d.event_id},
+            {("a", ""): a.event_id, ("b", ""): b.event_id, ("e", ""): e.event_id},
+            {("c", ""): c.event_id, ("d", ""): d.event_id},
         ]
 
         store = TestStateResolutionStore(persisted_events)
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9606ecc43b..b223dc750b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             # That should result in a single db query to lookup
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
 
+    def test_persisting_event_prefills_get_event_cache(self) -> None:
+        """
+        Test to make sure that the `_get_event_cache` is prefilled after we persist an
+        event and returns the updated value.
+        """
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "conflabulation"},
+            )
+        )
+
+        # First, check `_get_event_cache` for the event we just made
+        # to verify it's not in the cache.
+        res = self.store._get_event_cache.get_local((event.event_id,))
+        self.assertEqual(res, None, "Event was cached when it should not have been.")
+
+        with LoggingContext(name="test") as ctx:
+            # Persist the event which should invalidate then prefill the
+            # `_get_event_cache` so we don't return stale values.
+            # Side Note: Apparently, persisting an event isn't a transaction in the
+            # sense that it is recorded in the LoggingContext
+            persistence = self.hs.get_storage_controllers().persistence
+            assert persistence is not None
+            self.get_success(
+                persistence.persist_event(
+                    event,
+                    event_context,
+                )
+            )
+
+            # Check `_get_event_cache` again and we should see the updated fact
+            # that we now have the event cached after persisting it.
+            res = self.store._get_event_cache.get_local((event.event_id,))
+            self.assertEqual(res.event, event, "Event not cached as expected.")  # type: ignore
+
+            # Try and fetch the event from the database.
+            self.get_success(self.store.get_event(event.event_id))
+
+            # Verify that the database hit was avoided.
+            self.assertEqual(
+                ctx.get_resource_usage().evt_db_fetch_count,
+                0,
+                "Database was hit, which would not happen if event was cached.",
+            )
+
     def test_invalidate_cache_by_room_id(self) -> None:
         """
         Test to make sure that all events associated with the given `(room_id,)`
@@ -188,7 +237,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
         self.event_id = res["event_id"]
 
         # Reset the event cache so the tests start with it empty
-        self.get_success(self.store._get_event_cache.clear())
+        self.store._get_event_cache.clear()
 
     def test_simple(self) -> None:
         """Test that we cache events that we pull from the DB."""
@@ -205,7 +254,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
         """
 
         # Reset the event cache
-        self.get_success(self.store._get_event_cache.clear())
+        self.store._get_event_cache.clear()
 
         with LoggingContext("test") as ctx:
             # We keep hold of the event event though we never use it.
@@ -215,7 +264,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
 
         # Reset the event cache
-        self.get_success(self.store._get_event_cache.clear())
+        self.store._get_event_cache.clear()
 
         with LoggingContext("test") as ctx:
             self.get_success(self.store.get_event(self.event_id))
@@ -390,7 +439,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
         self.event_id = res["event_id"]
 
         # Reset the event cache so the tests start with it empty
-        self.get_success(self.store._get_event_cache.clear())
+        self.store._get_event_cache.clear()
 
     @contextmanager
     def blocking_get_event_calls(
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 56cb49d9b5..35f77052a7 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from twisted.internet import defer, reactor
 from twisted.internet.base import ReactorBase
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
-from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS
 from synapse.util import Clock
 
 from tests import unittest
@@ -131,6 +132,7 @@ class LockTestCase(unittest.HomeserverTestCase):
 
         # We simulate the process getting stuck by cancelling the looping call
         # that keeps the lock active.
+        assert lock._looping_call
         lock._looping_call.stop()
 
         # Wait for the lock to timeout.
@@ -166,4 +168,338 @@ class LockTestCase(unittest.HomeserverTestCase):
         # Now call the shutdown code
         self.get_success(self.store._on_shutdown())
 
-        self.assertEqual(self.store._live_tokens, {})
+        self.assertEqual(self.store._live_lock_tokens, {})
+
+
+class ReadWriteLockTestCase(unittest.HomeserverTestCase):
+    """Test the read/write lock implementation."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+
+    def test_acquire_write_contention(self) -> None:
+        """Test that we can only acquire one write lock at a time"""
+        # Track the number of tasks holding the lock.
+        # Should be at most 1.
+        in_lock = 0
+        max_in_lock = 0
+
+        release_lock: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            nonlocal in_lock
+            nonlocal max_in_lock
+
+            lock = await self.store.try_acquire_read_write_lock(
+                "name", "key", write=True
+            )
+            if not lock:
+                return
+
+            async with lock:
+                in_lock += 1
+                max_in_lock = max(max_in_lock, in_lock)
+
+                # Block to allow other tasks to attempt to take the lock.
+                await release_lock
+
+                in_lock -= 1
+
+        # Start 3 tasks.
+        task1 = defer.ensureDeferred(task())
+        task2 = defer.ensureDeferred(task())
+        task3 = defer.ensureDeferred(task())
+
+        # Give the reactor a kick so that the database transaction returns.
+        self.pump()
+
+        release_lock.callback(None)
+
+        # Run the tasks to completion.
+        # To work around `Linearizer`s using a different reactor to sleep when
+        # contended (#12841), we call `runUntilCurrent` on
+        # `twisted.internet.reactor`, which is a different reactor to that used
+        # by the homeserver.
+        assert isinstance(reactor, ReactorBase)
+        self.get_success(task1)
+        reactor.runUntilCurrent()
+        self.get_success(task2)
+        reactor.runUntilCurrent()
+        self.get_success(task3)
+
+        # At most one task should have held the lock at a time.
+        self.assertEqual(max_in_lock, 1)
+
+    def test_acquire_multiple_reads(self) -> None:
+        """Test that we can acquire multiple read locks at a time"""
+        # Track the number of tasks holding the lock.
+        in_lock = 0
+        max_in_lock = 0
+
+        release_lock: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            nonlocal in_lock
+            nonlocal max_in_lock
+
+            lock = await self.store.try_acquire_read_write_lock(
+                "name", "key", write=False
+            )
+            if not lock:
+                return
+
+            async with lock:
+                in_lock += 1
+                max_in_lock = max(max_in_lock, in_lock)
+
+                # Block to allow other tasks to attempt to take the lock.
+                await release_lock
+
+                in_lock -= 1
+
+        # Start 3 tasks.
+        task1 = defer.ensureDeferred(task())
+        task2 = defer.ensureDeferred(task())
+        task3 = defer.ensureDeferred(task())
+
+        # Give the reactor a kick so that the database transaction returns.
+        self.pump()
+
+        release_lock.callback(None)
+
+        # Run the tasks to completion.
+        # To work around `Linearizer`s using a different reactor to sleep when
+        # contended (#12841), we call `runUntilCurrent` on
+        # `twisted.internet.reactor`, which is a different reactor to that used
+        # by the homeserver.
+        assert isinstance(reactor, ReactorBase)
+        self.get_success(task1)
+        reactor.runUntilCurrent()
+        self.get_success(task2)
+        reactor.runUntilCurrent()
+        self.get_success(task3)
+
+        # At most one task should have held the lock at a time.
+        self.assertEqual(max_in_lock, 3)
+
+    def test_write_lock_acquired(self) -> None:
+        """Test that we can take out a write lock and that while we hold it
+        nobody else can take it out.
+        """
+        # First to acquire this lock, so it should complete
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        # Enter the context manager
+        self.get_success(lock.__aenter__())
+
+        # Attempting to acquire the lock again fails, as both read and write.
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        self.assertIsNone(lock3)
+
+        # Calling `is_still_valid` reports true.
+        self.assertTrue(self.get_success(lock.is_still_valid()))
+
+        # Drop the lock
+        self.get_success(lock.__aexit__(None, None, None))
+
+        # We can now acquire the lock again.
+        lock4 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock4 is not None
+        self.get_success(lock4.__aenter__())
+        self.get_success(lock4.__aexit__(None, None, None))
+
+    def test_read_lock_acquired(self) -> None:
+        """Test that we can take out a read lock and that while we hold it
+        only other reads can use it.
+        """
+        # First to acquire this lock, so it should complete
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        assert lock is not None
+
+        # Enter the context manager
+        self.get_success(lock.__aenter__())
+
+        # Attempting to acquire the write lock fails
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        # Attempting to acquire a read lock succeeds
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=False)
+        )
+        assert lock3 is not None
+        self.get_success(lock3.__aenter__())
+
+        # Calling `is_still_valid` reports true.
+        self.assertTrue(self.get_success(lock.is_still_valid()))
+
+        # Drop the first lock
+        self.get_success(lock.__aexit__(None, None, None))
+
+        # Attempting to acquire the write lock still fails, as lock3 is still
+        # active.
+        lock4 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock4)
+
+        # Drop the still open third lock
+        self.get_success(lock3.__aexit__(None, None, None))
+
+        # We can now acquire the lock again.
+        lock5 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock5 is not None
+        self.get_success(lock5.__aenter__())
+        self.get_success(lock5.__aexit__(None, None, None))
+
+    def test_maintain_lock(self) -> None:
+        """Test that we don't time out locks while they're still active (lock is
+        renewed in the background if the process is still alive)"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        self.get_success(lock.__aenter__())
+
+        # Wait for ages with the lock, we should not be able to get the lock.
+        for _ in range(10):
+            self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000))
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        self.get_success(lock.__aexit__(None, None, None))
+
+    def test_timeout_lock(self) -> None:
+        """Test that we time out locks if they're not updated for ages"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        assert lock is not None
+
+        self.get_success(lock.__aenter__())
+
+        # We simulate the process getting stuck by cancelling the looping call
+        # that keeps the lock active.
+        assert lock._looping_call
+        lock._looping_call.stop()
+
+        # Wait for the lock to timeout.
+        self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+        self.assertFalse(self.get_success(lock.is_still_valid()))
+
+    def test_drop(self) -> None:
+        """Test that dropping the context manager means we stop renewing the lock"""
+
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock)
+
+        del lock
+
+        # Wait for the lock to timeout.
+        self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+    def test_shutdown(self) -> None:
+        """Test that shutting down Synapse releases the locks"""
+        # Acquire two locks
+        lock = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key", write=True)
+        )
+        self.assertIsNotNone(lock)
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name", "key2", write=True)
+        )
+        self.assertIsNotNone(lock2)
+
+        # Now call the shutdown code
+        self.get_success(self.store._on_shutdown())
+
+        self.assertEqual(self.store._live_read_write_lock_tokens, {})
+
+    def test_acquire_multiple_locks(self) -> None:
+        """Tests that acquiring multiple locks at once works."""
+
+        # Take out multiple locks and ensure that we can't get those locks out
+        # again.
+        lock = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name2", "key2")], write=True
+            )
+        )
+        self.assertIsNotNone(lock)
+
+        assert lock is not None
+        self.get_success(lock.__aenter__())
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name2", "key2", write=False)
+        )
+        self.assertIsNone(lock3)
+
+        # Overlapping locks attempts will fail, and won't lock any locks.
+        lock4 = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name3", "key3")], write=True
+            )
+        )
+        self.assertIsNone(lock4)
+
+        lock5 = self.get_success(
+            self.store.try_acquire_read_write_lock("name3", "key3", write=True)
+        )
+        self.assertIsNotNone(lock5)
+        assert lock5 is not None
+        self.get_success(lock5.__aenter__())
+        self.get_success(lock5.__aexit__(None, None, None))
+
+        # Once we release the lock we can take out the locks again.
+        self.get_success(lock.__aexit__(None, None, None))
+
+        lock6 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNotNone(lock6)
+        assert lock6 is not None
+        self.get_success(lock6.__aenter__())
+        self.get_success(lock6.__aexit__(None, None, None))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5e1324a169..cbce26a725 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,7 +15,7 @@ import json
 import os
 import tempfile
 from typing import List, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -35,12 +35,11 @@ from synapse.types import DeviceListUpdates
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
     def setUp(self) -> None:
-        super(ApplicationServiceStoreTestCase, self).setUp()
+        super().setUp()
 
         self.as_yaml_files: List[str] = []
 
@@ -71,7 +70,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
             except Exception:
                 pass
 
-        super(ApplicationServiceStoreTestCase, self).tearDown()
+        super().tearDown()
 
     def _add_appservice(
         self, as_token: str, id: str, url: str, hs_token: str, sender: str
@@ -110,7 +109,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
 
 class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
     def setUp(self) -> None:
-        super(ApplicationServiceTransactionStoreTestCase, self).setUp()
+        super().setUp()
         self.as_yaml_files: List[str] = []
 
         self.hs.config.appservice.app_service_config_files = self.as_yaml_files
@@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
 
         # we aren't testing store._base stuff here, so mock this out
         # (ignore needed because Mypy won't allow us to assign to a method otherwise)
-        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))  # type: ignore[assignment]
+        self.store.get_events_as_list = AsyncMock(return_value=events)  # type: ignore[method-assign]
 
         self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
         self.get_success(self._insert_txn(service.id, 10, events))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fd619b64d4..abf7d0564d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from unittest.mock import Mock
+import logging
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -20,12 +20,18 @@ from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
-from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.background_updates import (
+    BackgroundUpdater,
+    ForeignKeyConstraint,
+    NotNullConstraint,
+    run_validate_constraint_and_delete_rows_schema_delta,
+)
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
 from tests.unittest import override_config
 
 
@@ -324,6 +330,28 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         self.update_handler.side_effect = update_short
         self.get_success(self.updates.do_next_background_update(False))
 
+    def test_failed_update_logs_exception_details(self) -> None:
+        needle = "RUH ROH RAGGY"
+
+        def failing_update(progress: JsonDict, count: int) -> int:
+            raise Exception(needle)
+
+        self.update_handler.side_effect = failing_update
+        self.update_handler.reset_mock()
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": "{}"},
+            )
+        )
+
+        with self.assertLogs(level=logging.ERROR) as logs:
+            # Expect a back-to-back RuntimeError to be raised
+            self.get_failure(self.updates.run_background_updates(False), RuntimeError)
+
+        self.assertTrue(any(needle in log for log in logs.output), logs.output)
+
 
 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -341,8 +369,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
 
         # Mock out the AsyncContextManager
         class MockCM:
-            __aenter__ = simple_async_mock(return_value=None)
-            __aexit__ = simple_async_mock(return_value=None)
+            __aenter__ = AsyncMock(return_value=None)
+            __aexit__ = AsyncMock(return_value=None)
 
         self._update_ctx_manager = MockCM
 
@@ -356,9 +384,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         # Register the callbacks with more mocks
         self.hs.get_module_api().register_background_update_controller_callbacks(
             on_update=self._on_update,
-            min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
-            default_batch_size=Mock(
-                return_value=make_awaitable(self._default_batch_size),
+            min_batch_size=AsyncMock(return_value=self._default_batch_size),
+            default_batch_size=AsyncMock(
+                return_value=self._default_batch_size,
             ),
         )
 
@@ -404,3 +432,225 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         self.pump()
         self._update_ctx_manager.__aexit__.assert_called()
         self.get_success(do_update_d)
+
+
+class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
+    """Tests the validate contraint and delete background handlers."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(
+            self.get_success(self.updates.has_completed_background_updates())
+        )
+
+        self.store = self.hs.get_datastores().main
+
+    def test_not_null_constraint(self) -> None:
+        # Create the initial tables, where we have some invalid data.
+        """Tests adding a not null constraint."""
+        table_sql = """
+            CREATE TABLE test_constraint(
+                a INT PRIMARY KEY,
+                b INT
+            );
+        """
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_not_null_constraint", lambda _: None, table_sql
+            )
+        )
+
+        # We add an index so that we can check that its correctly recreated when
+        # using SQLite.
+        index_sql = "CREATE INDEX test_index ON test_constraint(a)"
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_not_null_constraint", lambda _: None, index_sql
+            )
+        )
+
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+        )
+
+        # Now lets do the migration
+
+        table2_sqlite = """
+            CREATE TABLE test_constraint2(
+                a INT PRIMARY KEY,
+                b INT,
+                CONSTRAINT test_constraint_name CHECK (b is NOT NULL)
+            );
+        """
+
+        def delta(txn: LoggingTransaction) -> None:
+            run_validate_constraint_and_delete_rows_schema_delta(
+                txn,
+                ordering=1000,
+                update_name="test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=NotNullConstraint("b"),
+                sqlite_table_name="test_constraint2",
+                sqlite_table_schema=table2_sqlite,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_not_null_constraint",
+                delta,
+            )
+        )
+
+        if isinstance(self.store.database_engine, PostgresEngine):
+            # Postgres uses a background update
+            self.updates.register_background_validate_constraint_and_delete_rows(
+                "test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=NotNullConstraint("b"),
+                unique_columns=["a"],
+            )
+
+            # Tell the DataStore that it hasn't finished all updates yet
+            self.store.db_pool.updates._all_done = False
+
+            # Now let's actually drive the updates to completion
+            self.wait_for_background_updates()
+
+        # Check the correct values are in the new table.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="test_constraint",
+                keyvalues={},
+                retcols=("a", "b"),
+            )
+        )
+
+        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+        # And check that invalid rows get correctly rejected.
+        self.get_failure(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}),
+            exc=self.store.database_engine.module.IntegrityError,
+        )
+
+        # Check the index is still there for SQLite.
+        if isinstance(self.store.database_engine, Sqlite3Engine):
+            # Ensure the index exists in the schema.
+            self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table="sqlite_master",
+                    keyvalues={"tbl_name": "test_constraint"},
+                    retcol="name",
+                )
+            )
+
+    def test_foreign_constraint(self) -> None:
+        """Tests adding a not foreign key constraint."""
+
+        # Create the initial tables, where we have some invalid data.
+        base_sql = """
+            CREATE TABLE base_table(
+                b INT PRIMARY KEY
+            );
+        """
+
+        table_sql = """
+            CREATE TABLE test_constraint(
+                a INT PRIMARY KEY,
+                b INT NOT NULL
+            );
+        """
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_foreign_key_constraint", lambda _: None, base_sql
+            )
+        )
+        self.get_success(
+            self.store.db_pool.execute(
+                "test_foreign_key_constraint", lambda _: None, table_sql
+            )
+        )
+
+        self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1}))
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2})
+        )
+        self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3}))
+        self.get_success(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
+        )
+
+        table2_sqlite = """
+            CREATE TABLE test_constraint2(
+                a INT PRIMARY KEY,
+                b INT NOT NULL,
+                CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b)
+            );
+        """
+
+        def delta(txn: LoggingTransaction) -> None:
+            run_validate_constraint_and_delete_rows_schema_delta(
+                txn,
+                ordering=1000,
+                update_name="test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=ForeignKeyConstraint(
+                    "base_table", [("b", "b")], deferred=False
+                ),
+                sqlite_table_name="test_constraint2",
+                sqlite_table_schema=table2_sqlite,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_foreign_key_constraint",
+                delta,
+            )
+        )
+
+        if isinstance(self.store.database_engine, PostgresEngine):
+            # Postgres uses a background update
+            self.updates.register_background_validate_constraint_and_delete_rows(
+                "test_bg_update",
+                table="test_constraint",
+                constraint_name="test_constraint_name",
+                constraint=ForeignKeyConstraint(
+                    "base_table", [("b", "b")], deferred=False
+                ),
+                unique_columns=["a"],
+            )
+
+            # Tell the DataStore that it hasn't finished all updates yet
+            self.store.db_pool.updates._all_done = False
+
+            # Now let's actually drive the updates to completion
+            self.wait_for_background_updates()
+
+        # Check the correct values are in the new table.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="test_constraint",
+                keyvalues={},
+                retcols=("a", "b"),
+            )
+        )
+        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+
+        # And check that invalid rows get correctly rejected.
+        self.get_failure(
+            self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}),
+            exc=self.store.database_engine.module.IntegrityError,
+        )
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7de109966d..ceb9597dd3 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
 
-        self.assertEqual(latest_event_ids, [event_id_4])
+        self.assertEqual(latest_event_ids, {event_id_4})
 
     def test_basic_cleanup(self) -> None:
         """Test that extremities are correctly calculated in the presence of
@@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(latest_event_ids, [event_id_b])
+        self.assertEqual(latest_event_ids, {event_id_b})
 
     def test_chain_of_fail_cleanup(self) -> None:
         """Test that extremities are correctly calculated in the presence of
@@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(latest_event_ids, [event_id_b])
+        self.assertEqual(latest_event_ids, {event_id_b})
 
     def test_forked_graph_cleanup(self) -> None:
         r"""Test that extremities are correctly calculated in the presence of
@@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
+        self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
+        self.assertEqual(latest_event_ids, {event_id_b, event_id_c})
 
 
 class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index cd0079871c..6b9692c486 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from parameterized import parameterized
 
@@ -30,7 +30,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import make_request
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -66,15 +65,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
     def test_insert_new_client_ip_none_device_id(self) -> None:
@@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         lots_of_users = 100
         user_id = "@user:server"
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", "device_id"
@@ -529,15 +526,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
         # Register the background update to run again.
@@ -564,15 +561,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
     def test_old_user_ips_pruned(self) -> None:
@@ -643,15 +640,80 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result2[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
-            },
-            r,
+            }.items(),
+            r.items(),
+        )
+
+    def test_invalid_user_agents_are_ignored(self) -> None:
+        # First make sure we have completed all updates.
+        self.wait_for_background_updates()
+
+        user_id1 = "@user1:id"
+        user_id2 = "@user2:id"
+        device_id1 = "MY_DEVICE1"
+        device_id2 = "MY_DEVICE2"
+        access_token1 = "access_token1"
+        access_token2 = "access_token2"
+
+        # Insert a user IP 1
+        self.get_success(
+            self.store.store_device(
+                user_id1,
+                device_id1,
+                "display name1",
+            )
+        )
+        # Insert a user IP 2
+        self.get_success(
+            self.store.store_device(
+                user_id2,
+                device_id2,
+                "display name2",
+            )
+        )
+
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id1, access_token1, "ip", "sync-v3-proxy-", device_id1
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id2, access_token2, "ip", "user_agent", device_id2
+            )
+        )
+        # Force persisting to disk
+        self.reactor.advance(200)
+
+        # We should see that in the DB
+        result = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="user_ips",
+                keyvalues={},
+                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+                desc="get_user_ip_and_agents",
+            )
+        )
+
+        # ensure user1 is filtered out
+        self.assertEqual(
+            result,
+            [
+                {
+                    "access_token": access_token2,
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "device_id": device_id2,
+                    "last_seen": 0,
+                }
+            ],
         )
 
 
@@ -715,13 +777,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
             self.store.get_last_client_ip_by_device(self.user_id, device_id)
         )
         r = result[(self.user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": self.user_id,
                 "device_id": device_id,
                 "ip": expected_ip,
                 "user_agent": "Mozzila pizza",
                 "last_seen": 123456100,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f03807c8f9..58ab41cf26 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -58,13 +58,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         res = self.get_success(self.store.get_device("user_id", "device_id"))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device_id",
                 "display_name": "display_name",
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_devices_by_user(self) -> None:
@@ -80,21 +80,21 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device1",
                 "display_name": "display_name 1",
-            },
-            res["device1"],
+            }.items(),
+            res["device1"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device2",
                 "display_name": "display_name 2",
-            },
-            res["device2"],
+            }.items(),
+            res["device2"].items(),
         )
 
     def test_count_devices_by_users(self) -> None:
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9cb326d90a..f6df31aba4 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -31,7 +31,7 @@ room_key: RoomKey = {
 
 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver("server", federation_http_client=None)
+        hs = self.setup_test_homeserver("server")
         self.store = hs.get_datastores().main
         return hs
 
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 5fde3b9c78..2033377b52 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset(json, dev)
+        self.assertLessEqual(json.items(), dev.items())
 
     def test_reupload_key(self) -> None:
         now = 1470174257070
@@ -71,8 +71,12 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset(
-            {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
+        self.assertLessEqual(
+            {
+                "key": "value",
+                "unsigned": {"device_display_name": "display_name"},
+            }.items(),
+            dev.items(),
         )
 
     def test_multiple_devices(self) -> None:
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..b55dd07f14 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
-                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+                [
+                    (e, EventContext(self.hs.get_storage_controllers(), {}))
+                    for e in events
+                ],
             )
 
             # Actually call the function that calculates the auth chain stuff.
@@ -661,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
 
         # Add a bunch of state so that it takes multiple iterations of the
         # background update to process the room.
-        for i in range(0, 150):
+        for i in range(150):
             self.helper.send_state(
                 room_id, event_type="m.test", body={"index": i}, tok=self.token
             )
@@ -715,12 +718,12 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
 
         # Add a bunch of state so that it takes multiple iterations of the
         # background update to process the room.
-        for i in range(0, 150):
+        for i in range(150):
             self.helper.send_state(
                 room_id1, event_type="m.test", body={"index": i}, tok=self.token
             )
 
-        for i in range(0, 150):
+        for i in range(150):
             self.helper.send_state(
                 room_id2, event_type="m.test", body={"index": i}, tok=self.token
             )
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 81e50bdd55..d3e20f44b2 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,19 @@
 # limitations under the License.
 
 import datetime
-from typing import Dict, List, Tuple, Union, cast
+from typing import (
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Mapping,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from parameterized import parameterized
@@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
 import tests.unittest
 import tests.utils
 
+# The silly auth graph we use to test the auth difference algorithm,
+# where the top are the most recent events.
+#
+#   A   B
+#    \ /
+#  D  E
+#  \  |
+#   ` F   C
+#     |  /|
+#     G ´ |
+#     | \ |
+#     H   I
+#     |   |
+#     K   J
+
+AUTH_GRAPH: Dict[str, List[str]] = {
+    "a": ["e"],
+    "b": ["e"],
+    "c": ["g", "i"],
+    "d": ["f"],
+    "e": ["f"],
+    "f": ["g"],
+    "g": ["h", "i"],
+    "h": ["k"],
+    "i": ["j"],
+    "k": [],
+    "j": [],
+}
+
+DEPTH_GRAPH = {
+    "a": 7,
+    "b": 7,
+    "c": 4,
+    "d": 6,
+    "e": 6,
+    "f": 5,
+    "g": 3,
+    "h": 2,
+    "i": 2,
+    "k": 1,
+    "j": 1,
+}
+
+T = TypeVar("T")
+
+
+def get_all_topologically_sorted_orders(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> List[List[T]]:
+    """Given a set of nodes and a graph, return all possible topological
+    orderings.
+    """
+
+    # This is implemented by Kahn's algorithm, and forking execution each time
+    # we have a choice over which node to consider next.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph: Dict[T, Set[T]] = {}
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in set(edges):
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+    return _get_all_topologically_sorted_orders_inner(
+        reverse_graph, zero_degree, degree_map
+    )
+
+
+def _get_all_topologically_sorted_orders_inner(
+    reverse_graph: Dict[T, Set[T]],
+    zero_degree: List[T],
+    degree_map: Dict[T, int],
+) -> List[List[T]]:
+    new_paths = []
+
+    # Rather than only choosing *one* item from the list of nodes with zero
+    # degree, we "fork" execution and run the algorithm for each node in the
+    # zero degree.
+    for node in zero_degree:
+        new_degree_map = degree_map.copy()
+        new_zero_degree = zero_degree.copy()
+        new_zero_degree.remove(node)
+
+        for edge in reverse_graph.get(node, []):
+            if edge in new_degree_map:
+                new_degree_map[edge] -= 1
+                if new_degree_map[edge] == 0:
+                    new_zero_degree.append(edge)
+
+        paths = _get_all_topologically_sorted_orders_inner(
+            reverse_graph, new_zero_degree, new_degree_map
+        )
+        for path in paths:
+            path.insert(0, node)
+
+        new_paths.extend(paths)
+
+    if not new_paths:
+        return [[]]
+
+    return new_paths
+
+
+def get_all_topologically_consistent_subsets(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> Set[FrozenSet[T]]:
+    """Get all subsets of the graph where if node N is in the subgraph, then all
+    nodes that can reach that node (i.e. for all X there exists a path X -> N)
+    are in the subgraph.
+    """
+    all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
+
+    graph_subsets = set()
+    for ordering in all_topological_orderings:
+        ordering.reverse()
+
+        for idx in range(len(ordering)):
+            graph_subsets.add(frozenset(ordering[:idx]))
+
+    return graph_subsets
+
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
 class _BackfillSetupInfo:
@@ -83,7 +227,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 (room_id, event_id),
             )
 
-        for i in range(0, 20):
+        for i in range(20):
             self.get_success(
                 self.store.db_pool.runInteraction("insert", insert_event, i)
             )
@@ -91,7 +235,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # this should get the last ten
         r = self.get_success(self.store.get_prev_events_for_room(room_id))
         self.assertEqual(10, len(r))
-        for i in range(0, 10):
+        for i in range(10):
             self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
     def test_get_rooms_with_many_extremities(self) -> None:
@@ -99,8 +243,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         room2 = "#room2"
         room3 = "#room3"
 
-        def insert_event(txn: Cursor, i: int, room_id: str) -> None:
+        def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None:
             event_id = "$event_%i:local" % i
+
+            # We need to insert into events table to get around the foreign key constraint.
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                table="events",
+                values={
+                    "instance_name": "master",
+                    "stream_ordering": self.store._stream_id_gen.get_next_txn(txn),
+                    "topological_ordering": 1,
+                    "depth": 1,
+                    "event_id": event_id,
+                    "room_id": room_id,
+                    "type": EventTypes.Message,
+                    "processed": True,
+                    "outlier": False,
+                    "origin_server_ts": 0,
+                    "received_ts": 0,
+                    "sender": "@user:local",
+                    "contains_url": False,
+                    "state_key": None,
+                    "rejection_reason": None,
+                },
+            )
+
             txn.execute(
                 (
                     "INSERT INTO event_forward_extremities (room_id, event_id) "
@@ -109,15 +277,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 (room_id, event_id),
             )
 
-        for i in range(0, 20):
+        for i in range(20):
             self.get_success(
                 self.store.db_pool.runInteraction("insert", insert_event, i, room1)
             )
             self.get_success(
-                self.store.db_pool.runInteraction("insert", insert_event, i, room2)
+                self.store.db_pool.runInteraction(
+                    "insert", insert_event, i + 100, room2
+                )
             )
             self.get_success(
-                self.store.db_pool.runInteraction("insert", insert_event, i, room3)
+                self.store.db_pool.runInteraction(
+                    "insert", insert_event, i + 200, room3
+                )
             )
 
         # Test simple case
@@ -144,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"
 
-        # The silly auth graph we use to test the auth difference algorithm,
-        # where the top are the most recent events.
-        #
-        #   A   B
-        #    \ /
-        #  D  E
-        #  \  |
-        #   ` F   C
-        #     |  /|
-        #     G ´ |
-        #     | \ |
-        #     H   I
-        #     |   |
-        #     K   J
-
-        auth_graph: Dict[str, List[str]] = {
-            "a": ["e"],
-            "b": ["e"],
-            "c": ["g", "i"],
-            "d": ["f"],
-            "e": ["f"],
-            "f": ["g"],
-            "g": ["h", "i"],
-            "h": ["k"],
-            "i": ["j"],
-            "k": [],
-            "j": [],
-        }
-
-        depth_map = {
-            "a": 7,
-            "b": 7,
-            "c": 4,
-            "d": 6,
-            "e": 6,
-            "f": 5,
-            "g": 3,
-            "h": 2,
-            "i": 2,
-            "k": 1,
-            "j": 1,
-        }
-
         # Mark the room as maybe having a cover index.
 
         def store_room(txn: LoggingTransaction) -> None:
@@ -210,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         def insert_event(txn: LoggingTransaction) -> None:
             stream_ordering = 0
 
-            for event_id in auth_graph:
+            for event_id in AUTH_GRAPH:
                 stream_ordering += 1
-                depth = depth_map[event_id]
+                depth = DEPTH_GRAPH[event_id]
 
                 self.store.db_pool.simple_insert_txn(
                     txn,
@@ -232,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                    for event_id in auth_graph
+                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                    for event_id in AUTH_GRAPH
                 ],
             )
 
@@ -316,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # Now actually test that various combinations give the right result:
+        self.assert_auth_diff_is_expected(room_id)
+
+    @parameterized.expand(
+        [
+            [graph_subset]
+            for graph_subset in get_all_topologically_consistent_subsets(
+                AUTH_GRAPH, AUTH_GRAPH
+            )
+        ]
+    )
+    def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
+        """Test that if we only have a chain cover index on a partial subset of
+        the room we still get the correct auth chain difference.
+
+        We do this by removing the chain cover index for every valid subset of the
+        graph.
+        """
+        room_id = self._setup_auth_chain(True)
+
+        for event_id in graph_subset:
+            # Remove chain cover from that event.
+            self.get_success(
+                self.store.db_pool.simple_delete(
+                    table="event_auth_chains",
+                    keyvalues={"event_id": event_id},
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    table="event_auth_chain_to_calculate",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "type": "",
+                        "state_key": "",
+                    },
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+
+        self.assert_auth_diff_is_expected(room_id)
 
+    def assert_auth_diff_is_expected(self, room_id: str) -> None:
+        """Assert the auth chain difference returns the correct answers."""
         difference = self.get_success(
             self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )
@@ -924,215 +1097,42 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
         self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
 
-    def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
+    def test_get_event_ids_with_failed_pull_attempts(self) -> None:
         """
-        Sets up a room with various insertion event backward extremities to test
-        backfill functions against.
-
-        Returns:
-            _BackfillSetupInfo including the `room_id` to test against and
-            `depth_map` of events in the room
+        Test to make sure we properly get event_ids based on whether they have any
+        failed pull attempts.
         """
-        room_id = "!backfill-room-test:some-host"
-
-        depth_map: Dict[str, int] = {
-            "1": 1,
-            "2": 2,
-            "insertion_eventA": 3,
-            "3": 4,
-            "insertion_eventB": 5,
-            "4": 6,
-            "5": 7,
-        }
-
-        def populate_db(txn: LoggingTransaction) -> None:
-            # Insert the room to satisfy the foreign key constraint of
-            # `event_failed_pull_attempts`
-            self.store.db_pool.simple_insert_txn(
-                txn,
-                "rooms",
-                {
-                    "room_id": room_id,
-                    "creator": "room_creator_user_id",
-                    "is_public": True,
-                    "room_version": "6",
-                },
-            )
-
-            # Insert our server events
-            stream_ordering = 0
-            for event_id, depth in depth_map.items():
-                self.store.db_pool.simple_insert_txn(
-                    txn,
-                    table="events",
-                    values={
-                        "event_id": event_id,
-                        "type": EventTypes.MSC2716_INSERTION
-                        if event_id.startswith("insertion_event")
-                        else "test_regular_type",
-                        "room_id": room_id,
-                        "depth": depth,
-                        "topological_ordering": depth,
-                        "stream_ordering": stream_ordering,
-                        "processed": True,
-                        "outlier": False,
-                    },
-                )
-
-                if event_id.startswith("insertion_event"):
-                    self.store.db_pool.simple_insert_txn(
-                        txn,
-                        table="insertion_event_extremities",
-                        values={
-                            "event_id": event_id,
-                            "room_id": room_id,
-                        },
-                    )
-
-                stream_ordering += 1
-
-        self.get_success(
-            self.store.db_pool.runInteraction(
-                "_setup_room_for_insertion_backfill_tests_populate_db",
-                populate_db,
-            )
-        )
-
-        return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
-
-    def test_get_insertion_event_backward_extremities_in_room(self) -> None:
-        """
-        Test to make sure only insertion event backward extremities that are
-        older and come before the `current_depth` are returned.
-        """
-        setup_info = self._setup_room_for_insertion_backfill_tests()
-        room_id = setup_info.room_id
-        depth_map = setup_info.depth_map
-
-        # Try at "insertion_eventB"
-        backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(
-                room_id, depth_map["insertion_eventB"], limit=100
-            )
-        )
-        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
-
-        # Try at "insertion_eventA"
-        backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(
-                room_id, depth_map["insertion_eventA"], limit=100
-            )
-        )
-        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        # Event "2" has a depth of 2 but is not included here because we only
-        # know the approximate depth of 5 from our event "3".
-        self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
-
-    def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
-        self,
-    ) -> None:
-        """
-        Test to make sure that insertion events we have attempted to backfill
-        (and within backoff timeout duration) do not show up as an event to
-        backfill again.
-        """
-        setup_info = self._setup_room_for_insertion_backfill_tests()
-        room_id = setup_info.room_id
-        depth_map = setup_info.depth_map
-
-        # Record some attempts to backfill these events which will make
-        # `get_insertion_event_backward_extremities_in_room` exclude them
-        # because we haven't passed the backoff interval.
-        self.get_success(
-            self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventA", "fake cause"
-            )
-        )
-
-        # No time has passed since we attempted to backfill ^
-
-        # Try at "insertion_eventB"
-        backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(
-                room_id, depth_map["insertion_eventB"], limit=100
-            )
-        )
-        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        # Only the backfill points that we didn't record earlier exist here.
-        self.assertEqual(backfill_event_ids, ["insertion_eventB"])
-
-    def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
-        self,
-    ) -> None:
-        """
-        Test to make sure after we fake attempt to backfill event
-        "insertion_eventA" many times, we can see retry and see the
-        "insertion_eventA" again after the backoff timeout duration has
-        exceeded.
-        """
-        setup_info = self._setup_room_for_insertion_backfill_tests()
-        room_id = setup_info.room_id
-        depth_map = setup_info.depth_map
+        # Create the room
+        user_id = self.register_user("alice", "test")
+        tok = self.login("alice", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
 
-        # Record some attempts to backfill these events which will make
-        # `get_backfill_points_in_room` exclude them because we
-        # haven't passed the backoff interval.
-        self.get_success(
-            self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventB", "fake cause"
-            )
-        )
-        self.get_success(
-            self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventA", "fake cause"
-            )
-        )
-        self.get_success(
-            self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventA", "fake cause"
-            )
-        )
         self.get_success(
             self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventA", "fake cause"
+                room_id, "$failed_event_id1", "fake cause"
             )
         )
         self.get_success(
             self.store.record_event_failed_pull_attempt(
-                room_id, "insertion_eventA", "fake cause"
+                room_id, "$failed_event_id2", "fake cause"
             )
         )
 
-        # Now advance time by 2 hours and we should only be able to see
-        # "insertion_eventB" because we have waited long enough for the single
-        # attempt (2^1 hours) but we still shouldn't see "insertion_eventA"
-        # because we haven't waited long enough for this many attempts.
-        self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
-
-        # Try at "insertion_eventA" and make sure that "insertion_eventA" is not
-        # in the list because we've already attempted many times
-        backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(
-                room_id, depth_map["insertion_eventA"], limit=100
+        event_ids_with_failed_pull_attempts = self.get_success(
+            self.store.get_event_ids_with_failed_pull_attempts(
+                event_ids=[
+                    "$failed_event_id1",
+                    "$fresh_event_id1",
+                    "$failed_event_id2",
+                    "$fresh_event_id2",
+                ]
             )
         )
-        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertEqual(backfill_event_ids, [])
-
-        # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
-        # see if we can now backfill it
-        self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
 
-        # Try at "insertion_eventA" again after we advanced enough time and we
-        # should see "insertion_eventA" again
-        backfill_points = self.get_success(
-            self.store.get_insertion_event_backward_extremities_in_room(
-                room_id, depth_map["insertion_eventA"], limit=100
-            )
+        self.assertEqual(
+            event_ids_with_failed_pull_attempts,
+            {"$failed_event_id1", "$failed_event_id2"},
         )
-        backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
-        self.assertEqual(backfill_event_ids, ["insertion_eventA"])
 
     def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
         """
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
deleted file mode 100644
index 5d7c13e6d0..0000000000
--- a/tests/storage/test_keys.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import signedjson.key
-import signedjson.types
-import unpaddedbase64
-
-from synapse.storage.keys import FetchKeyResult
-
-import tests.unittest
-
-
-def decode_verify_key_base64(
-    key_id: str, key_base64: str
-) -> signedjson.types.VerifyKey:
-    key_bytes = unpaddedbase64.decode_base64(key_base64)
-    return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
-
-
-KEY_1 = decode_verify_key_base64(
-    "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
-)
-KEY_2 = decode_verify_key_base64(
-    "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
-)
-
-
-class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
-    def test_get_server_signature_keys(self) -> None:
-        store = self.hs.get_datastores().main
-
-        key_id_1 = "ed25519:key1"
-        key_id_2 = "ed25519:KEY_ID_2"
-        self.get_success(
-            store.store_server_signature_keys(
-                "from_server",
-                10,
-                {
-                    ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
-                    ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
-                },
-            )
-        )
-
-        res = self.get_success(
-            store.get_server_signature_keys(
-                [
-                    ("server1", key_id_1),
-                    ("server1", key_id_2),
-                    ("server1", "ed25519:key3"),
-                ]
-            )
-        )
-
-        self.assertEqual(len(res.keys()), 3)
-        res1 = res[("server1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.verify_key.version, "key1")
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("server1", key_id_2)]
-        self.assertEqual(res2.verify_key, KEY_2)
-        # version comes from the ID it was stored with
-        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
-        self.assertEqual(res2.valid_until_ts, 200)
-
-        # non-existent result gives None
-        self.assertIsNone(res[("server1", "ed25519:key3")])
-
-    def test_cache(self) -> None:
-        """Check that updates correctly invalidate the cache."""
-
-        store = self.hs.get_datastores().main
-
-        key_id_1 = "ed25519:key1"
-        key_id_2 = "ed25519:key2"
-
-        self.get_success(
-            store.store_server_signature_keys(
-                "from_server",
-                0,
-                {
-                    ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
-                    ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
-                },
-            )
-        )
-
-        res = self.get_success(
-            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        )
-        self.assertEqual(len(res.keys()), 2)
-
-        res1 = res[("srv1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("srv1", key_id_2)]
-        self.assertEqual(res2.verify_key, KEY_2)
-        self.assertEqual(res2.valid_until_ts, 200)
-
-        # we should be able to look up the same thing again without a db hit
-        res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
-        self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
-
-        new_key_2 = signedjson.key.get_verify_key(
-            signedjson.key.generate_signing_key("key2")
-        )
-        d = store.store_server_signature_keys(
-            "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
-        )
-        self.get_success(d)
-
-        res = self.get_success(
-            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        )
-        self.assertEqual(len(res.keys()), 2)
-
-        res1 = res[("srv1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("srv1", key_id_2)]
-        self.assertEqual(res2.verify_key, new_key_2)
-        self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 27f450e22d..b8823d6993 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -20,7 +20,7 @@ from tests import unittest
 
 class DataStoreTestCase(unittest.HomeserverTestCase):
     def setUp(self) -> None:
-        super(DataStoreTestCase, self).setUp()
+        super().setUp()
 
         self.store = self.hs.get_datastores().main
 
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 2827738379..49366440ce 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, List
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -21,7 +21,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[method-assign]
 
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(None)
-        )
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
         d = self.store.populate_monthly_active_users("user_id")
         self.get_success(d)
 
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[method-assign]
+        self.store.user_last_seen_monthly_active = AsyncMock(
+            return_value=self.hs.get_clock().time_msec()
         )
 
         d = self.store.populate_monthly_active_users("user_id")
@@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index f9cf0fcb82..95f99f4130 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
@@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(
             "Frank",
-            (
-                self.get_success(
-                    self.store.get_profile_displayname(self.u_frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_displayname(self.u_frank))),
         )
 
         # test set to None
         self.get_success(self.store.set_profile_displayname(self.u_frank, None))
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
+            self.get_success(self.store.get_profile_displayname(self.u_frank))
         )
 
     def test_avatar_url(self) -> None:
@@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(
             "http://my.site/here",
-            (
-                self.get_success(
-                    self.store.get_profile_avatar_url(self.u_frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.u_frank))),
         )
 
         # test set to None
         self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
+            self.get_success(self.store.get_profile_avatar_url(self.u_frank))
         )
 
     def test_profiles_bg_migration(self) -> None:
@@ -89,7 +82,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
 
             self.get_success(self.store.db_pool.runInteraction("", f))
 
-        for i in range(0, 70):
+        for i in range(70):
             self.get_success(
                 self.store.db_pool.simple_insert(
                     "profiles",
@@ -122,7 +115,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         )
 
         expected_values = []
-        for i in range(0, 70):
+        for i in range(70):
             expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
 
         res = self.get_success(
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 857e2caf2e..0282673167 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
     servlets = [room.register_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver("server", federation_http_client=None)
+        hs = self.setup_test_homeserver("server")
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 05ea802008..0cca34d355 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, override_config
@@ -35,22 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEqual(
-            {
+            UserInfo(
                 # TODO(paul): Surely this field should be 'user_id', not 'name'
-                "name": self.user_id,
-                "password_hash": self.pwhash,
-                "admin": 0,
-                "is_guest": 0,
-                "consent_version": None,
-                "consent_ts": None,
-                "consent_server_notice_sent": None,
-                "appservice_id": None,
-                "creation_ts": 0,
-                "user_type": None,
-                "deactivated": 0,
-                "shadow_banned": 0,
-                "approved": 1,
-            },
+                user_id=UserID.from_string(self.user_id),
+                is_admin=False,
+                is_guest=False,
+                consent_server_notice_sent=None,
+                consent_ts=None,
+                consent_version=None,
+                appservice_id=None,
+                creation_ts=0,
+                user_type=None,
+                is_deactivated=False,
+                locked=False,
+                is_shadow_banned=False,
+                approved=True,
+            ),
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
 
@@ -63,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user
-        self.assertEqual(user["consent_version"], "1")
-        self.assertGreater(user["consent_ts"], before_consent)
-        self.assertLess(user["consent_ts"], self.clock.time_msec())
+        self.assertEqual(user.consent_version, "1")
+        self.assertIsNotNone(user.consent_ts)
+        assert user.consent_ts is not None
+        self.assertGreater(user.consent_ts, before_consent)
+        self.assertLess(user.consent_ts, self.clock.time_msec())
 
     def test_add_tokens(self) -> None:
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -213,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertTrue(user["approved"])
+        self.assertTrue(user.approved)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)
@@ -226,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertFalse(user["approved"])
+        self.assertFalse(user.approved)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertFalse(approved)
@@ -246,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         self.assertIsNotNone(user)
         assert user is not None
-        self.assertEqual(user["approved"], 1)
+        self.assertEqual(user.approved, 1)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 966aafea6f..809c9f175d 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -45,9 +45,7 @@ def fake_listdir(filepath: str) -> List[str]:
 
 class WorkerSchemaTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver(
-            federation_http_client=None, homeserver_to_use=GenericWorkerServer
-        )
+        hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
         return hs
 
     def default_config(self) -> JsonDict:
@@ -55,6 +53,7 @@ class WorkerSchemaTests(HomeserverTestCase):
 
         # Mark this as a worker app.
         conf["worker_app"] = "yes"
+        conf["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}}
 
         return conf
 
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 71ec74eadc..1e27f2c275 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -44,13 +44,13 @@ class RoomStoreTestCase(HomeserverTestCase):
     def test_get_room(self) -> None:
         res = self.get_success(self.store.get_room(self.room.to_string()))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "room_id": self.room.to_string(),
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_room_unknown_room(self) -> None:
@@ -59,13 +59,13 @@ class RoomStoreTestCase(HomeserverTestCase):
     def test_get_room_with_stats(self) -> None:
         res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "room_id": self.room.to_string(),
                 "creator": self.u_creator.to_string(),
                 "public": True,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_room_with_stats_unknown_room(self) -> None:
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index f183c38477..52ffa91c81 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase):
             result = self.get_success(
                 store.search_msgs([self.room_id], query, ["content.body"])
             )
-            self.assertEquals(
+            self.assertEqual(
                 result["count"],
                 1 if expect_to_contain else 0,
                 f"expected '{query}' to match '{self.PHRASE}'"
                 if expect_to_contain
                 else f"'{query}' unexpectedly matched '{self.PHRASE}'",
             )
-            self.assertEquals(
+            self.assertEqual(
                 len(result["results"]),
                 1 if expect_to_contain else 0,
                 "results array length should match count",
@@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase):
             result = self.get_success(
                 store.search_rooms([self.room_id], query, ["content.body"], 10)
             )
-            self.assertEquals(
+            self.assertEqual(
                 result["count"],
                 1 if expect_to_contain else 0,
                 f"expected '{query}' to match '{self.PHRASE}'"
                 if expect_to_contain
                 else f"'{query}' unexpectedly matched '{self.PHRASE}'",
             )
-            self.assertEquals(
+            self.assertEqual(
                 len(result["results"]),
                 1 if expect_to_contain else 0,
                 "results array length should match count",
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index db9ee9955e..ef06b50dbb 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -17,7 +17,6 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.server import HomeServer
 from synapse.storage.databases.main.transactions import DestinationRetryTimings
 from synapse.util import Clock
-from synapse.util.retryutils import MAX_RETRY_INTERVAL
 
 from tests.unittest import HomeserverTestCase
 
@@ -33,15 +32,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
         destination retries, as well as testing tht we can set and get
         correctly.
         """
-        d = self.store.get_destination_retry_timings("example.com")
-        r = self.get_success(d)
+        r = self.get_success(self.store.get_destination_retry_timings("example.com"))
         self.assertIsNone(r)
 
-        d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
-        self.get_success(d)
+        self.get_success(
+            self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
+        )
 
-        d = self.store.get_destination_retry_timings("example.com")
-        r = self.get_success(d)
+        r = self.get_success(self.store.get_destination_retry_timings("example.com"))
 
         self.assertEqual(
             DestinationRetryTimings(
@@ -58,8 +56,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
         self.get_success(d)
 
     def test_large_destination_retry(self) -> None:
+        max_retry_interval_ms = (
+            self.hs.config.federation.destination_max_retry_interval_ms
+        )
         d = self.store.set_destination_retry_timings(
-            "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
+            "example.com",
+            max_retry_interval_ms,
+            max_retry_interval_ms,
+            max_retry_interval_ms,
         )
         self.get_success(d)
 
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 15ea4770bd..22f074982f 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -38,5 +38,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
         db_pool = self.hs.get_datastores().databases[0]
 
         # force txn limit to roll over at least once
-        for _ in range(0, 1001):
+        for _ in range(1001):
             self.get_success_or_raise(db_pool.runInteraction("test_select", do_select))
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index bab802f56e..d4637d9d1e 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -45,7 +45,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
 
             self.get_success(self.store.db_pool.runInteraction("", f))
 
-        for i in range(0, 70):
+        for i in range(70):
             self.get_success(
                 self.store.db_pool.simple_insert(
                     "user_filters",
@@ -82,7 +82,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
         )
 
         expected_values = []
-        for i in range(0, 70):
+        for i in range(70):
             expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
 
         res = self.get_success(
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 0e3fc2a77f..29be8cdbd0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 
-from tests.test_utils import make_awaitable
 from tests.unittest import TestCase
 
 
@@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
 class PartialCurrentStateTrackerTestCase(TestCase):
     def setUp(self) -> None:
         self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+        self.mock_store.is_partial_state_room = mock.AsyncMock()
 
         self.tracker = PartialCurrentStateTracker(self.mock_store)
 
     def test_does_not_block_for_full_state_rooms(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+        self.mock_store.is_partial_state_room.return_value = False
 
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_blocks_for_partial_room_state(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d = ensureDeferred(self.tracker.await_full_state("room_id"))
 
@@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_cancellation(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
         self.assertNoResult(d1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 6d15ac7597..1b0504709e 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Collection, List, Optional, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -31,7 +31,6 @@ from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -52,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.store = self.hs.get_datastores().main
 
         # Figure out what the most recent event is
-        most_recent = self.get_success(
-            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
-        )[0]
+        most_recent = next(
+            iter(
+                self.get_success(
+                    self.hs.get_datastores().main.get_latest_event_ids_in_room(
+                        self.room_id
+                    )
+                )
+            )
+        )
 
         join_event = make_event_from_dict(
             {
@@ -81,7 +86,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         ) -> None:
             pass
 
-        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[method-assign]
         self.client = self.hs.get_federation_client()
 
         async def _check_sigs_and_hash_for_pulled_events_and_fetch(
@@ -101,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Make sure we actually joined the room
         self.assertEqual(
-            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
-            "$join:test.serv",
+            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
+            {"$join:test.serv"},
         )
 
     def test_cant_hide_direct_ancestors(self) -> None:
@@ -128,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.http_client.post_json = post_json
 
         # Figure out what the most recent event is
-        most_recent = self.get_success(
-            self.store.get_latest_event_ids_in_room(self.room_id)
-        )[0]
+        most_recent = next(
+            iter(
+                self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+            )
+        )
 
         # Now lie about an event
         lying_event = make_event_from_dict(
@@ -166,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Make sure the invalid event isn't there
         extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
-        self.assertEqual(extrem[0], "$join:test.serv")
+        self.assertEqual(extrem, {"$join:test.serv"})
 
     def test_retry_device_list_resync(self) -> None:
         """Tests that device lists are marked as stale if they couldn't be synced, and
@@ -191,12 +198,12 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register the mock on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[assignment]
+        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[method-assign]
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
         store = self.hs.get_datastores().main
-        store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
+        store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
@@ -241,27 +248,24 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register mock device list retrieval on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
+        federation_client.query_user_devices = AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "user_id": remote_user_id,
+                "stream_id": 1,
+                "devices": [],
+                "master_key": {
                     "user_id": remote_user_id,
-                    "stream_id": 1,
-                    "devices": [],
-                    "master_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                    },
-                    "self_signing_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["self_signing"],
-                        "keys": {
-                            "ed25519:"
-                            + remote_self_signing_key: remote_self_signing_key
-                        },
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+                "self_signing_key": {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
                     },
-                }
-            )
+                },
+            }
         )
 
         # Resync the device list.
diff --git a/tests/test_server.py b/tests/test_server.py
index e266c06a2c..36162cd1f5 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,7 @@ from tests.http.server._base import test_disconnect
 from tests.server import (
     FakeChannel,
     FakeSite,
-    ThreadedMemoryReactorClock,
+    get_clock,
     make_request,
     setup_test_homeserver,
 )
@@ -46,12 +46,11 @@ from tests.server import (
 
 class JsonResourceTests(unittest.TestCase):
     def setUp(self) -> None:
-        self.reactor = ThreadedMemoryReactorClock()
-        self.hs_clock = Clock(self.reactor)
+        reactor, clock = get_clock()
+        self.reactor = reactor
         self.homeserver = setup_test_homeserver(
             self.addCleanup,
-            federation_http_client=None,
-            clock=self.hs_clock,
+            clock=clock,
             reactor=self.reactor,
         )
 
@@ -209,7 +208,13 @@ class JsonResourceTests(unittest.TestCase):
 
 class OptionsResourceTests(unittest.TestCase):
     def setUp(self) -> None:
-        self.reactor = ThreadedMemoryReactorClock()
+        reactor, clock = get_clock()
+        self.reactor = reactor
+        self.homeserver = setup_test_homeserver(
+            self.addCleanup,
+            clock=clock,
+            reactor=self.reactor,
+        )
 
         class DummyResource(Resource):
             isLeaf = True
@@ -242,6 +247,7 @@ class OptionsResourceTests(unittest.TestCase):
             "1.0",
             max_request_body_size=4096,
             reactor=self.reactor,
+            hs=self.homeserver,
         )
 
         # render the request and return the channel
@@ -268,7 +274,7 @@ class OptionsResourceTests(unittest.TestCase):
         )
         self.assertEqual(
             channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
-            [b"Synapse-Trace-Id"],
+            [b"Synapse-Trace-Id, Server"],
         )
 
     def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None:
@@ -344,7 +350,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
             await self.callback(request)
 
     def setUp(self) -> None:
-        self.reactor = ThreadedMemoryReactorClock()
+        reactor, _ = get_clock()
+        self.reactor = reactor
 
     def test_good_response(self) -> None:
         async def callback(request: SynapseRequest) -> None:
@@ -462,9 +469,9 @@ class DirectServeJsonResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeJsonResource` cancellation."""
 
     def setUp(self) -> None:
-        self.reactor = ThreadedMemoryReactorClock()
-        self.clock = Clock(self.reactor)
-        self.resource = CancellableDirectServeJsonResource(self.clock)
+        reactor, clock = get_clock()
+        self.reactor = reactor
+        self.resource = CancellableDirectServeJsonResource(clock)
         self.site = FakeSite(self.resource, self.reactor)
 
     def test_cancellable_disconnect(self) -> None:
@@ -496,9 +503,9 @@ class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeHtmlResource` cancellation."""
 
     def setUp(self) -> None:
-        self.reactor = ThreadedMemoryReactorClock()
-        self.clock = Clock(self.reactor)
-        self.resource = CancellableDirectServeHtmlResource(self.clock)
+        reactor, clock = get_clock()
+        self.reactor = reactor
+        self.resource = CancellableDirectServeHtmlResource(clock)
         self.site = FakeSite(self.resource, self.reactor)
 
     def test_cancellable_disconnect(self) -> None:
diff --git a/tests/test_state.py b/tests/test_state.py
index ddf59916b1..9c8679cc1d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -28,7 +28,7 @@ from unittest.mock import Mock
 
 from twisted.internet import defer
 
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
@@ -240,7 +240,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_macaroon_generator.return_value = MacaroonGenerator(
             clock, "tesths", b"verysecret"
         )
-        hs.get_auth.return_value = Auth(hs)
+        hs.get_auth.return_value = InternalAuth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
         hs.get_storage_controllers.return_value = storage_controllers
 
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group_before_event)
+        assert context.state_group_before_event is not None
+        assert context.state_group is not None
+        self.assertEqual(
+            context.state_group_deltas.get(
+                (context.state_group_before_event, context.state_group)
+            ),
+            {(event.type, event.state_key): event.event_id},
+        )
         self.assertNotEqual(context.state_group_before_event, context.state_group)
-        self.assertEqual(context.state_group_before_event, context.prev_group)
-        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(
@@ -709,7 +714,7 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
+        self.dummy_store.get_events = store.get_events  # type: ignore[method-assign]
 
         context: EventContext
         context = yield self._get_context(
@@ -768,7 +773,7 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
+        self.dummy_store.get_events = store.get_events  # type: ignore[method-assign]
 
         context: EventContext
         context = yield self._get_context(
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52424aa087..64a49488c6 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -85,7 +85,9 @@ class TermsTestCase(unittest.HomeserverTestCase):
             }
         }
         self.assertIsInstance(channel.json_body["params"], dict)
-        self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+        self.assertLessEqual(
+            channel.json_body["params"].items(), expected_params.items()
+        )
 
         # We have to complete the dummy auth stage before completing the terms stage
         request_data = {
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e5dae670a7..fa731426cd 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -18,10 +18,8 @@ Utilities for running the unit tests
 import json
 import sys
 import warnings
-from asyncio import Future
 from binascii import unhexlify
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
-from unittest.mock import Mock
+from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
 
 import attr
 import zope.interface
@@ -33,7 +31,7 @@ from twisted.web.http import RESPONSES
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IResponse
 
-from synapse.types import JsonDict
+from synapse.types import JsonSerializable
 
 if TYPE_CHECKING:
     from sys import UnraisableHookArgs
@@ -57,27 +55,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
     raise Exception("awaitable has not yet completed")
 
 
-def make_awaitable(result: TV) -> Awaitable[TV]:
-    """
-    Makes an awaitable, suitable for mocking an `async` function.
-    This uses Futures as they can be awaited multiple times so can be returned
-    to multiple callers.
-    """
-    future: Future[TV] = Future()
-    future.set_result(result)
-    return future
-
-
 def setup_awaitable_errors() -> Callable[[], None]:
     """
     Convert warnings from a non-awaited coroutines into errors.
     """
     warnings.simplefilter("error", RuntimeWarning)
 
-    # unraisablehook was added in Python 3.8.
-    if not hasattr(sys, "unraisablehook"):
-        return lambda: None
-
     # State shared between unraisablehook and check_for_unraisable_exceptions.
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
@@ -100,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
     return cleanup
 
 
-def simple_async_mock(
-    return_value: Optional[TV] = None, raises: Optional[Exception] = None
-) -> Mock:
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 # Type ignore: it does not fully implement IResponse, but is good enough for tests
 @zope.interface.implementer(IResponse)
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -145,7 +116,7 @@ class FakeResponse:  # type: ignore[misc]
         protocol.connectionLost(Failure(ResponseDone()))
 
     @classmethod
-    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+    def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse":
         headers = Headers({"Content-Type": ["application/json"]})
         body = json.dumps(payload).encode("utf-8")
         return cls(code=code, body=body, headers=headers)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index b522163a34..199bb06a81 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -40,10 +40,9 @@ def setup_logging() -> None:
     """
     root_logger = logging.getLogger()
 
-    log_format = (
-        "%(asctime)s - %(name)s - %(lineno)d - "
-        "%(levelname)s - %(request)s - %(message)s"
-    )
+    # We exclude `%(asctime)s` from this format because the Twisted logger adds its own
+    # timestamp
+    log_format = "%(name)s - %(lineno)d - " "%(levelname)s - %(request)s - %(message)s"
 
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
@@ -54,4 +53,16 @@ def setup_logging() -> None:
     log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
     root_logger.setLevel(log_level)
 
+    # In order to not add noise by default (since we only log ERROR messages for trial
+    # tests as configured above), we only enable this for developers for looking for
+    # more INFO or DEBUG.
+    if root_logger.isEnabledFor(logging.INFO):
+        # Log when events are (maybe unexpectedly) filtered out of responses in tests. It's
+        # just nice to be able to look at the CI log and figure out why an event isn't being
+        # returned.
+        logging.getLogger("synapse.visibility.filtered_event_debug").setLevel(
+            logging.DEBUG
+        )
+
+    # Blow away the pyo3-log cache so that it reloads the configuration.
     reset_logging_config()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 9ed330f554..434902c3f0 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM"
 
 class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
     def setUp(self) -> None:
-        super(FilterEventsForServerTestCase, self).setUp()
+        super().setUp()
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self._storage_controllers = self.hs.get_storage_controllers()
@@ -51,12 +51,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         # before we do that, we persist some other events to act as state.
         self._inject_visibility("@admin:hs", "joined")
-        for i in range(0, 10):
+        for i in range(10):
             self._inject_room_member("@resident%i:hs" % i)
 
         events_to_filter = []
 
-        for i in range(0, 10):
+        for i in range(10):
             user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
             evt = self._inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
@@ -74,7 +74,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
-        for i in range(0, 5):
+        for i in range(5):
             self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
             self.assertNotIn("a", filtered[i].content)
 
@@ -177,7 +177,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        for i in range(0, len(events_to_filter)):
+        for i in range(len(events_to_filter)):
             self.assertEqual(
                 events_to_filter[i].event_id,
                 filtered[i].event_id,
diff --git a/tests/unittest.py b/tests/unittest.py
index 623c5a75a2..99ad02eb06 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import gc
 import hashlib
 import hmac
@@ -59,7 +60,7 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.http.server import JsonResource
+from synapse.http.server import JsonResource, OptionsResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -69,6 +70,7 @@ from synapse.logging.context import (
 )
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
@@ -150,7 +152,11 @@ def deepcopy_config(config: _TConfig) -> _TConfig:
     return new_config
 
 
-_make_homeserver_config_obj_cache: Dict[str, Union[RootConfig, Config]] = {}
+@functools.lru_cache(maxsize=8)
+def _parse_config_dict(config: str) -> RootConfig:
+    config_obj = HomeServerConfig()
+    config_obj.parse_config_dict(json.loads(config), "", "")
+    return config_obj
 
 
 def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
@@ -164,21 +170,7 @@ def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
     but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed,
     to avoid validating the whole configuration every time.
     """
-    cache_key = json.dumps(config)
-
-    if cache_key in _make_homeserver_config_obj_cache:
-        # Cache hit: reuse the existing instance
-        config_obj = _make_homeserver_config_obj_cache[cache_key]
-    else:
-        # Cache miss; create the actual instance
-        config_obj = HomeServerConfig()
-        config_obj.parse_config_dict(config, "", "")
-
-        # Add to the cache
-        _make_homeserver_config_obj_cache[cache_key] = config_obj
-
-    assert isinstance(config_obj, RootConfig)
-
+    config_obj = _parse_config_dict(json.dumps(config, sort_keys=True))
     return deepcopy_config(config_obj)
 
 
@@ -322,7 +314,7 @@ class HomeserverTestCase(TestCase):
         servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
         hijack_auth: Whether to hijack auth to return the user specified
-        in user_id.
+           in user_id.
     """
 
     hijack_auth: ClassVar[bool] = True
@@ -367,6 +359,7 @@ class HomeserverTestCase(TestCase):
             server_version_string="1",
             max_request_body_size=4096,
             reactor=self.reactor,
+            hs=self.hs,
         )
 
         from tests.rest.client.utils import RestHelper
@@ -403,9 +396,9 @@ class HomeserverTestCase(TestCase):
                     )
 
                 # Type ignore: mypy doesn't like us assigning to methods.
-                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[assignment]
-                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[assignment]
-                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[assignment]
+                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[method-assign]
+                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[method-assign]
+                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[method-assign]
 
         if self.needs_threadpool:
             self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
@@ -466,7 +459,7 @@ class HomeserverTestCase(TestCase):
         The default calls `self.create_resource_dict` and builds the resultant dict
         into a tree.
         """
-        root_resource = Resource()
+        root_resource = OptionsResource()
         create_resource_tree(self.create_resource_dict(), root_resource)
         return root_resource
 
@@ -866,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
         self.get_success(
-            hs.get_datastores().main.store_server_keys_json(
+            hs.get_datastores().main.store_server_keys_response(
                 self.OTHER_SERVER_NAME,
-                verify_key_id,
                 from_server=self.OTHER_SERVER_NAME,
-                ts_now_ms=clock.time_msec(),
-                ts_expires_ms=clock.time_msec() + 10000,
-                key_json_bytes=canonicaljson.encode_canonical_json(
-                    {
-                        "verify_keys": {
-                            verify_key_id: {
-                                "key": signedjson.key.encode_verify_key_base64(
-                                    verify_key
-                                )
-                            }
+                ts_added_ms=clock.time_msec(),
+                verify_keys={
+                    verify_key_id: FetchKeyResult(
+                        verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
+                    ),
+                },
+                response_json={
+                    "verify_keys": {
+                        verify_key_id: {
+                            "key": signedjson.key.encode_verify_key_base64(verify_key)
                         }
                     }
-                ),
+                },
             )
         )
 
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 13f1edd533..7e8725e610 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,7 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Set, Tuple, cast
+from typing import (
+    Any,
+    Generator,
+    Iterable,
+    List,
+    Mapping,
+    NoReturn,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 from unittest import mock
 
 from twisted.internet import defer, reactor
@@ -29,7 +40,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 
 from tests import unittest
 from tests.test_utils import get_awaitable_result
@@ -37,21 +48,21 @@ from tests.test_utils import get_awaitable_result
 logger = logging.getLogger(__name__)
 
 
-def run_on_reactor():
-    d: "Deferred[int]" = defer.Deferred()
+def run_on_reactor() -> "Deferred[int]":
+    d: "Deferred[int]" = Deferred()
     cast(IReactorTime, reactor).callLater(0, d.callback, 0)
     return make_deferred_yieldable(d)
 
 
 class DescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
-    def test_cache(self):
+    def test_cache(self) -> Generator["Deferred[Any]", object, None]:
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1, arg2):
+            def fn(self, arg1: int, arg2: int) -> str:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
@@ -77,15 +88,15 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     @defer.inlineCallbacks
-    def test_cache_num_args(self):
+    def test_cache_num_args(self) -> Generator["Deferred[Any]", object, None]:
         """Only the first num_args arguments should matter to the cache"""
 
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached(num_args=1)
-            def fn(self, arg1, arg2):
+            def fn(self, arg1: int, arg2: int) -> str:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
@@ -111,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     @defer.inlineCallbacks
-    def test_cache_uncached_args(self):
+    def test_cache_uncached_args(self) -> Generator["Deferred[Any]", object, None]:
         """
         Only the arguments not named in uncached_args should matter to the cache
 
@@ -123,10 +134,10 @@ class DescriptorTestCase(unittest.TestCase):
             # Note that it is important that this is not the last argument to
             # test behaviour of skipping arguments properly.
             @descriptors.cached(uncached_args=("arg2",))
-            def fn(self, arg1, arg2, arg3):
+            def fn(self, arg1: int, arg2: int, arg3: int) -> str:
                 return self.mock(arg1, arg2, arg3)
 
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
         obj = Cls()
@@ -152,15 +163,15 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     @defer.inlineCallbacks
-    def test_cache_kwargs(self):
+    def test_cache_kwargs(self) -> Generator["Deferred[Any]", object, None]:
         """Test that keyword arguments are treated properly"""
 
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1, kwarg1=2):
+            def fn(self, arg1: int, kwarg1: int = 2) -> str:
                 return self.mock(arg1, kwarg1=kwarg1)
 
         obj = Cls()
@@ -188,12 +199,12 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "fish")
         obj.mock.assert_not_called()
 
-    def test_cache_with_sync_exception(self):
+    def test_cache_with_sync_exception(self) -> None:
         """If the wrapped function throws synchronously, things should continue to work"""
 
         class Cls:
             @cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> NoReturn:
                 raise SynapseError(100, "mai spoon iz too big!!1")
 
         obj = Cls()
@@ -209,23 +220,24 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
-    def test_cache_with_async_exception(self):
+    def test_cache_with_async_exception(self) -> None:
         """The wrapped function returns a failure"""
 
         class Cls:
-            result = None
+            result: Optional[Deferred] = None
             call_count = 0
 
             @cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> Deferred:
                 self.call_count += 1
+                assert self.result is not None
                 return self.result
 
         obj = Cls()
         callbacks: Set[str] = set()
 
         # set off an asynchronous request
-        origin_d: Deferred = defer.Deferred()
+        origin_d: Deferred = Deferred()
         obj.result = origin_d
 
         d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
@@ -260,17 +272,17 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(self.successResultOf(d3), 100)
         self.assertEqual(obj.call_count, 2)
 
-    def test_cache_logcontexts(self):
+    def test_cache_logcontexts(self) -> Deferred:
         """Check that logcontexts are set and restored correctly when
         using the cache."""
 
-        complete_lookup: Deferred = defer.Deferred()
+        complete_lookup: Deferred = Deferred()
 
         class Cls:
             @descriptors.cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> "Deferred[int]":
                 @defer.inlineCallbacks
-                def inner_fn():
+                def inner_fn() -> Generator["Deferred[object]", object, int]:
                     with PreserveLoggingContext():
                         yield complete_lookup
                     return 1
@@ -278,13 +290,13 @@ class DescriptorTestCase(unittest.TestCase):
                 return inner_fn()
 
         @defer.inlineCallbacks
-        def do_lookup():
+        def do_lookup() -> Generator["Deferred[Any]", object, int]:
             with LoggingContext("c1") as c1:
                 r = yield obj.fn(1)
                 self.assertEqual(current_context(), c1)
-            return r
+            return cast(int, r)
 
-        def check_result(r):
+        def check_result(r: int) -> None:
             self.assertEqual(r, 1)
 
         obj = Cls()
@@ -304,15 +316,15 @@ class DescriptorTestCase(unittest.TestCase):
 
         return defer.gatherResults([d1, d2])
 
-    def test_cache_logcontexts_with_exception(self):
+    def test_cache_logcontexts_with_exception(self) -> "Deferred[None]":
         """Check that the cache sets and restores logcontexts correctly when
         the lookup function throws an exception"""
 
         class Cls:
             @descriptors.cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> Deferred:
                 @defer.inlineCallbacks
-                def inner_fn():
+                def inner_fn() -> Generator["Deferred[Any]", object, NoReturn]:
                     # we want this to behave like an asynchronous function
                     yield run_on_reactor()
                     raise SynapseError(400, "blah")
@@ -320,7 +332,7 @@ class DescriptorTestCase(unittest.TestCase):
                 return inner_fn()
 
         @defer.inlineCallbacks
-        def do_lookup():
+        def do_lookup() -> Generator["Deferred[object]", object, None]:
             with LoggingContext("c1") as c1:
                 try:
                     d = obj.fn(1)
@@ -347,13 +359,13 @@ class DescriptorTestCase(unittest.TestCase):
         return d1
 
     @defer.inlineCallbacks
-    def test_cache_default_args(self):
+    def test_cache_default_args(self) -> Generator["Deferred[Any]", object, None]:
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1, arg2=2, arg3=3):
+            def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str:
                 return self.mock(arg1, arg2, arg3)
 
         obj = Cls()
@@ -384,27 +396,27 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
-    def test_cache_iterable(self):
+    def test_cache_iterable(self) -> None:
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached(iterable=True)
-            def fn(self, arg1, arg2):
+            def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
                 return self.mock(arg1, arg2)
 
         obj = Cls()
 
-        obj.mock.return_value = ["spam", "eggs"]
+        obj.mock.return_value = ("spam", "eggs")
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
-        obj.mock.return_value = ["chips"]
+        obj.mock.return_value = ("chips",)
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
@@ -412,17 +424,17 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(len(obj.fn.cache.cache), 3)
 
         r = obj.fn(1, 2)
-        self.assertEqual(r.result, ["spam", "eggs"])
+        self.assertEqual(r.result, ("spam", "eggs"))
         r = obj.fn(1, 3)
-        self.assertEqual(r.result, ["chips"])
+        self.assertEqual(r.result, ("chips",))
         obj.mock.assert_not_called()
 
-    def test_cache_iterable_with_sync_exception(self):
+    def test_cache_iterable_with_sync_exception(self) -> None:
         """If the wrapped function throws synchronously, things should continue to work"""
 
         class Cls:
             @descriptors.cached(iterable=True)
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> NoReturn:
                 raise SynapseError(100, "mai spoon iz too big!!1")
 
         obj = Cls()
@@ -438,20 +450,20 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
-    def test_invalidate_cascade(self):
+    def test_invalidate_cascade(self) -> None:
         """Invalidations should cascade up through cache contexts"""
 
         class Cls:
             @cached(cache_context=True)
-            async def func1(self, key, cache_context):
+            async def func1(self, key: str, cache_context: _CacheContext) -> int:
                 return await self.func2(key, on_invalidate=cache_context.invalidate)
 
             @cached(cache_context=True)
-            async def func2(self, key, cache_context):
+            async def func2(self, key: str, cache_context: _CacheContext) -> int:
                 return await self.func3(key, on_invalidate=cache_context.invalidate)
 
             @cached(cache_context=True)
-            async def func3(self, key, cache_context):
+            async def func3(self, key: str, cache_context: _CacheContext) -> int:
                 self.invalidate = cache_context.invalidate
                 return 42
 
@@ -463,13 +475,13 @@ class DescriptorTestCase(unittest.TestCase):
         obj.invalidate()
         top_invalidate.assert_called_once()
 
-    def test_cancel(self):
+    def test_cancel(self) -> None:
         """Test that cancelling a lookup does not cancel other lookups"""
         complete_lookup: "Deferred[None]" = Deferred()
 
         class Cls:
             @cached()
-            async def fn(self, arg1):
+            async def fn(self, arg1: int) -> str:
                 await complete_lookup
                 return str(arg1)
 
@@ -488,7 +500,7 @@ class DescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d1, CancelledError)
         self.assertEqual(d2.result, "123")
 
-    def test_cancel_logcontexts(self):
+    def test_cancel_logcontexts(self) -> None:
         """Test that cancellation does not break logcontexts.
 
         * The `CancelledError` must be raised with the correct logcontext.
@@ -501,14 +513,14 @@ class DescriptorTestCase(unittest.TestCase):
             inner_context_was_finished = False
 
             @cached()
-            async def fn(self, arg1):
+            async def fn(self, arg1: int) -> str:
                 await make_deferred_yieldable(complete_lookup)
                 self.inner_context_was_finished = current_context().finished
                 return str(arg1)
 
         obj = Cls()
 
-        async def do_lookup():
+        async def do_lookup() -> None:
             with LoggingContext("c1") as c1:
                 try:
                     await obj.fn(123)
@@ -542,10 +554,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """
 
     @defer.inlineCallbacks
-    def test_passthrough(self):
+    def test_passthrough(self) -> Generator["Deferred[Any]", object, None]:
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 return key
 
         a = A()
@@ -554,12 +566,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEqual((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
-    def test_hit(self):
+    def test_hit(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
 
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 callcount[0] += 1
                 return key
 
@@ -572,12 +584,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEqual(callcount[0], 1)
 
     @defer.inlineCallbacks
-    def test_invalidate(self):
+    def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
 
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 callcount[0] += 1
                 return key
 
@@ -592,48 +604,48 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(callcount[0], 2)
 
-    def test_invalidate_missing(self):
+    def test_invalidate_missing(self) -> None:
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 return key
 
         A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
-    def test_max_entries(self):
+    def test_max_entries(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
 
         class A:
             @cached(max_entries=10)
-            def func(self, key):
+            def func(self, key: int) -> int:
                 callcount[0] += 1
                 return key
 
         a = A()
 
-        for k in range(0, 12):
+        for k in range(12):
             yield a.func(k)
 
         self.assertEqual(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
-        for k in range(0, 12):
+        for k in range(12):
             yield a.func(k)
 
         self.assertTrue(
             callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
         )
 
-    def test_prefill(self):
+    def test_prefill(self) -> None:
         callcount = [0]
 
         d = defer.succeed(123)
 
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> "Deferred[int]":
                 callcount[0] += 1
                 return d
 
@@ -645,18 +657,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEqual(callcount[0], 0)
 
     @defer.inlineCallbacks
-    def test_invalidate_context(self):
+    def test_invalidate_context(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
         callcount2 = [0]
 
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 callcount[0] += 1
                 return key
 
             @cached(cache_context=True)
-            def func2(self, key, cache_context):
+            def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
                 callcount2[0] += 1
                 return self.func(key, on_invalidate=cache_context.invalidate)
 
@@ -678,18 +690,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEqual(callcount2[0], 2)
 
     @defer.inlineCallbacks
-    def test_eviction_context(self):
+    def test_eviction_context(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
         callcount2 = [0]
 
         class A:
             @cached(max_entries=2)
-            def func(self, key):
+            def func(self, key: str) -> str:
                 callcount[0] += 1
                 return key
 
             @cached(cache_context=True)
-            def func2(self, key, cache_context):
+            def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
                 callcount2[0] += 1
                 return self.func(key, on_invalidate=cache_context.invalidate)
 
@@ -715,18 +727,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEqual(callcount2[0], 3)
 
     @defer.inlineCallbacks
-    def test_double_get(self):
+    def test_double_get(self) -> Generator["Deferred[Any]", object, None]:
         callcount = [0]
         callcount2 = [0]
 
         class A:
             @cached()
-            def func(self, key):
+            def func(self, key: str) -> str:
                 callcount[0] += 1
                 return key
 
             @cached(cache_context=True)
-            def func2(self, key, cache_context):
+            def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
                 callcount2[0] += 1
                 return self.func(key, on_invalidate=cache_context.invalidate)
 
@@ -763,17 +775,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
-    def test_cache(self):
+    def test_cache(self) -> Generator["Deferred[Any]", object, None]:
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1, arg2):
+            def fn(self, arg1: int, arg2: int) -> None:
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1, arg2):
+            async def list_fn(
+                self, args1: Iterable[int], arg2: int
+            ) -> Mapping[int, str]:
                 context = current_context()
                 assert isinstance(context, LoggingContext)
                 assert context.name == "c1"
@@ -824,23 +838,23 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.mock.assert_called_once_with({40}, 2)
             self.assertEqual(r, {10: "fish", 40: "gravy"})
 
-    def test_concurrent_lookups(self):
+    def test_concurrent_lookups(self) -> None:
         """All concurrent lookups should get the same result"""
 
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> None:
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            def list_fn(self, args1) -> "Deferred[dict]":
+            def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
                 return self.mock(args1)
 
         obj = Cls()
-        deferred_result: "Deferred[dict]" = Deferred()
+        deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
         obj.mock.return_value = deferred_result
 
         # start off several concurrent lookups of the same key
@@ -867,19 +881,19 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         self.assertEqual(self.successResultOf(d3), {10: "peas"})
 
     @defer.inlineCallbacks
-    def test_invalidate(self):
+    def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
         """Make sure that invalidation callbacks are called."""
 
         class Cls:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.mock = mock.Mock()
 
             @descriptors.cached()
-            def fn(self, arg1, arg2):
+            def fn(self, arg1: int, arg2: int) -> None:
                 pass
 
             @descriptors.cachedList(cached_method_name="fn", list_name="args1")
-            async def list_fn(self, args1, arg2):
+            async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
                 return self.mock(args1, arg2)
@@ -908,17 +922,17 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         invalidate0.assert_called_once()
         invalidate1.assert_called_once()
 
-    def test_cancel(self):
+    def test_cancel(self) -> None:
         """Test that cancelling a lookup does not cancel other lookups"""
         complete_lookup: "Deferred[None]" = Deferred()
 
         class Cls:
             @cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> None:
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args):
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await complete_lookup
                 return {arg: str(arg) for arg in args}
 
@@ -936,7 +950,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         self.failureResultOf(d1, CancelledError)
         self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
 
-    def test_cancel_logcontexts(self):
+    def test_cancel_logcontexts(self) -> None:
         """Test that cancellation does not break logcontexts.
 
         * The `CancelledError` must be raised with the correct logcontext.
@@ -949,18 +963,18 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             inner_context_was_finished = False
 
             @cached()
-            def fn(self, arg1):
+            def fn(self, arg1: int) -> None:
                 pass
 
             @cachedList(cached_method_name="fn", list_name="args")
-            async def list_fn(self, args):
+            async def list_fn(self, args: List[int]) -> Mapping[int, str]:
                 await make_deferred_yieldable(complete_lookup)
                 self.inner_context_was_finished = current_context().finished
                 return {arg: str(arg) for arg in args}
 
         obj = Cls()
 
-        async def do_lookup():
+        async def do_lookup() -> None:
             with LoggingContext("c1") as c1:
                 try:
                     await obj.list_fn([123])
@@ -983,7 +997,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         )
         self.assertEqual(current_context(), SENTINEL_CONTEXT)
 
-    def test_num_args_mismatch(self):
+    def test_num_args_mismatch(self) -> None:
         """
         Make sure someone does not accidentally use @cachedList on a method with
         a mismatch in the number args to the underlying single cache method.
@@ -991,14 +1005,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
 
         class Cls:
             @descriptors.cached(tree=True)
-            def fn(self, room_id, event_id):
+            def fn(self, room_id: str, event_id: str) -> None:
                 pass
 
             # This is wrong ❌. `@cachedList` expects to be given the same number
             # of arguments as the underlying cached function, just with one of
             # the arguments being an iterable
             @descriptors.cachedList(cached_method_name="fn", list_name="keys")
-            def list_fn(self, keys: Iterable[Tuple[str, str]]):
+            def list_fn(self, keys: Iterable[Tuple[str, str]]) -> None:
                 pass
 
             # Corrected syntax ✅
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 91cac9822a..05983ed434 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -60,11 +60,9 @@ class ObservableDeferredTest(TestCase):
         observer1.addBoth(check_called_first)
 
         # store the results
-        results: List[Optional[ObservableDeferred[int]]] = [None, None]
+        results: List[Optional[int]] = [None, None]
 
-        def check_val(
-            res: ObservableDeferred[int], idx: int
-        ) -> ObservableDeferred[int]:
+        def check_val(res: int, idx: int) -> int:
             results[idx] = res
             return res
 
@@ -93,14 +91,14 @@ class ObservableDeferredTest(TestCase):
         observer1.addBoth(check_called_first)
 
         # store the results
-        results: List[Optional[ObservableDeferred[str]]] = [None, None]
+        results: List[Optional[Failure]] = [None, None]
 
-        def check_val(res: ObservableDeferred[str], idx: int) -> None:
+        def check_failure(res: Failure, idx: int) -> None:
             results[idx] = res
             return None
 
-        observer1.addErrback(check_val, 0)
-        observer2.addErrback(check_val, 1)
+        observer1.addErrback(check_failure, 0)
+        observer2.addErrback(check_failure, 1)
 
         try:
             raise Exception("gah!")
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 5f8f4e76b5..4bcd17a6fc 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -11,12 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.retryutils import (
-    MIN_RETRY_INTERVAL,
-    RETRY_MULTIPLIER,
-    NotRetryingDestination,
-    get_retry_limiter,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
 from tests.unittest import HomeserverTestCase
 
@@ -42,6 +37,11 @@ class RetryLimiterTestCase(HomeserverTestCase):
 
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
+        min_retry_interval_ms = (
+            self.hs.config.federation.destination_min_retry_interval_ms
+        )
+        retry_multiplier = self.hs.config.federation.destination_retry_multiplier
+
         self.pump(1)
         try:
             with limiter:
@@ -57,7 +57,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         assert new_timings is not None
         self.assertEqual(new_timings.failure_ts, failure_ts)
         self.assertEqual(new_timings.retry_last_ts, failure_ts)
-        self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
+        self.assertEqual(new_timings.retry_interval, min_retry_interval_ms)
 
         # now if we try again we should get a failure
         self.get_failure(
@@ -68,7 +68,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         # advance the clock and try again
         #
 
-        self.pump(MIN_RETRY_INTERVAL)
+        self.pump(min_retry_interval_ms)
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
@@ -87,16 +87,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
         self.assertEqual(new_timings.failure_ts, failure_ts)
         self.assertEqual(new_timings.retry_last_ts, retry_ts)
         self.assertGreaterEqual(
-            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+            new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 0.5
         )
         self.assertLessEqual(
-            new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+            new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 2.0
         )
 
         #
         # one more go, with success
         #
-        self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+        self.reactor.advance(min_retry_interval_ms * retry_multiplier * 2.0)
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
@@ -108,3 +108,54 @@ class RetryLimiterTestCase(HomeserverTestCase):
 
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)
+
+    def test_max_retry_interval(self) -> None:
+        """Test that `destination_max_retry_interval` setting works as expected"""
+        store = self.hs.get_datastores().main
+
+        destination_max_retry_interval_ms = (
+            self.hs.config.federation.destination_max_retry_interval_ms
+        )
+
+        self.get_success(get_retry_limiter("test_dest", self.clock, store))
+        self.pump(1)
+
+        failure_ts = self.clock.time_msec()
+
+        # Simulate reaching destination_max_retry_interval
+        self.get_success(
+            store.set_destination_retry_timings(
+                "test_dest",
+                failure_ts=failure_ts,
+                retry_last_ts=failure_ts,
+                retry_interval=destination_max_retry_interval_ms,
+            )
+        )
+
+        # Check it fails
+        self.get_failure(
+            get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+        )
+
+        # Get past retry_interval and we can try again, and still throw an error to continue the backoff
+        self.reactor.advance(destination_max_retry_interval_ms / 1000 + 1)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
+        self.pump(1)
+        try:
+            with limiter:
+                self.pump(1)
+                raise AssertionError("argh")
+        except AssertionError:
+            pass
+
+        self.pump()
+
+        # retry_interval does not increase and stays at destination_max_retry_interval_ms
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+        assert new_timings is not None
+        self.assertEqual(new_timings.retry_interval, destination_max_retry_interval_ms)
+
+        # Check it fails
+        self.get_failure(
+            get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+        )
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
new file mode 100644
index 0000000000..8665aeb50c
--- /dev/null
+++ b/tests/util/test_task_scheduler.py
@@ -0,0 +1,208 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from twisted.internet.task import deferLater
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase, override_config
+
+
+class TestTaskScheduler(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+        self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task")
+        self.task_scheduler.register_action(self._raising_task, "_raising_task")
+        self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
+
+    async def _test_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        # This test task will copy the parameters to the result
+        result = None
+        if task.params:
+            result = task.params
+        return (TaskStatus.COMPLETE, result, None)
+
+    def test_schedule_task(self) -> None:
+        """Schedule a task in the future with some parameters to be copied as a result and check it executed correctly.
+        Also check that it get removed after `KEEP_TASKS_FOR_MS`."""
+        timestamp = self.clock.time_msec() + 30 * 1000
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+                timestamp=timestamp,
+                params={"val": 1},
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.SCHEDULED)
+        self.assertIsNone(task.result)
+
+        # The timestamp being 30s after now the task should been executed
+        # after the first scheduling loop is run
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)
+        assert task.result is not None
+        # The passed parameter should have been copied to the result
+        self.assertTrue(task.result.get("val") == 1)
+
+        # Let's wait for the complete task to be deleted and hence unavailable
+        self.reactor.advance((TaskScheduler.KEEP_TASKS_FOR_MS / 1000) + 1)
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        self.assertIsNone(task)
+
+    async def _sleeping_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        # Sleep for a second
+        await deferLater(self.reactor, 1, lambda: None)
+        return TaskStatus.COMPLETE, None, None
+
+    def test_schedule_lot_of_tasks(self) -> None:
+        """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
+        task_ids = []
+        for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
+            task_ids.append(
+                self.get_success(
+                    self.task_scheduler.schedule_task(
+                        "_sleeping_task",
+                        params={"val": i},
+                    )
+                )
+            )
+
+        # This is to give the time to the active tasks to finish
+        self.reactor.advance(1)
+
+        # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
+        # is still scheduled.
+        tasks = [
+            self.get_success(self.task_scheduler.get_task(task_id))
+            for task_id in task_ids
+        ]
+
+        self.assertEquals(
+            len(
+                [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
+            ),
+            TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
+        )
+
+        scheduled_tasks = [
+            t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
+        ]
+        self.assertEquals(len(scheduled_tasks), 1)
+
+        # We need to wait for the next run of the scheduler loop
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+        self.reactor.advance(1)
+
+        # Check that the last task has been properly executed after the next scheduler loop run
+        prev_scheduled_task = self.get_success(
+            self.task_scheduler.get_task(scheduled_tasks[0].id)
+        )
+        assert prev_scheduled_task is not None
+        self.assertEquals(
+            prev_scheduled_task.status,
+            TaskStatus.COMPLETE,
+        )
+
+    async def _raising_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        raise Exception("raising")
+
+    def test_schedule_raising_task(self) -> None:
+        """Schedule a task raising an exception and check it runs to failure and report exception content."""
+        task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.FAILED)
+        self.assertEqual(task.error, "raising")
+
+    async def _resumable_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        if task.result and "in_progress" in task.result:
+            return TaskStatus.COMPLETE, {"success": True}, None
+        else:
+            await self.task_scheduler.update_task(task.id, result={"in_progress": True})
+            # Await forever to simulate an aborted task because of a restart
+            await deferLater(self.reactor, 2**16, lambda: None)
+            # This should never been called
+            return TaskStatus.ACTIVE, None, None
+
+    def test_schedule_resumable_task(self) -> None:
+        """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
+        task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.ACTIVE)
+
+        # Simulate a synapse restart by emptying the list of running tasks
+        self.task_scheduler._running_tasks = set()
+        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)
+        assert task.result is not None
+        self.assertTrue(task.result.get("success"))
+
+
+class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+
+    async def _test_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        return (TaskStatus.COMPLETE, None, None)
+
+    @override_config({"run_background_tasks_on": "worker1"})
+    def test_schedule_task(self) -> None:
+        """Check that a task scheduled to run now is launch right away on the background worker."""
+        bg_worker_hs = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={"worker_name": "worker1"},
+        )
+        bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task")
+
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)