summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/appservice/test_scheduler.py54
-rw-r--r--tests/crypto/test_event_signing.py6
-rw-r--r--tests/federation/test_federation_sender.py23
-rw-r--r--tests/federation/transport/test_knocking.py3
-rw-r--r--tests/handlers/test_appservice.py121
-rw-r--r--tests/handlers/test_deactivate_account.py25
-rw-r--r--tests/handlers/test_e2e_keys.py6
-rw-r--r--tests/handlers/test_federation.py77
-rw-r--r--tests/handlers/test_federation_event.py225
-rw-r--r--tests/handlers/test_oidc.py7
-rw-r--r--tests/handlers/test_profile.py2
-rw-r--r--tests/handlers/test_sync.py4
-rw-r--r--tests/handlers/test_user_directory.py2
-rw-r--r--tests/module_api/test_api.py22
-rw-r--r--tests/replication/slave/storage/test_events.py4
-rw-r--r--tests/rest/admin/test_media.py8
-rw-r--r--tests/rest/admin/test_server_notice.py12
-rw-r--r--tests/rest/admin/test_user.py15
-rw-r--r--tests/rest/client/test_account.py11
-rw-r--r--tests/rest/client/test_account_data.py75
-rw-r--r--tests/rest/client/test_relations.py207
-rw-r--r--tests/rest/client/test_rooms.py2
-rw-r--r--tests/rest/client/test_sync.py68
-rw-r--r--tests/rest/client/test_third_party_rules.py41
-rw-r--r--tests/rest/client/utils.py11
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py16
-rw-r--r--tests/rest/media/v1/test_url_preview.py43
-rw-r--r--tests/server.py10
-rw-r--r--tests/storage/databases/main/test_lock.py8
-rw-r--r--tests/storage/test__base.py14
-rw-r--r--tests/storage/test_appservice.py75
-rw-r--r--tests/storage/test_cleanup_extrems.py20
-rw-r--r--tests/storage/test_devices.py14
-rw-r--r--tests/storage/test_id_generators.py12
-rw-r--r--tests/storage/test_monthly_active_users.py47
-rw-r--r--tests/storage/test_redaction.py62
-rw-r--r--tests/storage/test_roommember.py25
-rw-r--r--tests/storage/test_stream.py4
-rw-r--r--tests/test_phone_home.py12
-rw-r--r--tests/test_visibility.py71
-rw-r--r--tests/unittest.py127
-rw-r--r--tests/util/test_linearizer.py224
-rw-r--r--tests/utils.py8
43 files changed, 1174 insertions, 649 deletions
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 1cbb059357..0b22afdc75 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
 )
 from synapse.logging.context import make_deferred_yieldable
 from synapse.server import HomeServer
+from synapse.types import DeviceListUpdates
 from synapse.util import Clock
 
 from tests import unittest
@@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             to_device_messages=[],  # txn made and saved
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
         self.assertEqual(0, len(self.txnctrl.recoverers))  # no recoverer made
         txn.complete.assert_called_once_with(self.store)  # txn completed
@@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             to_device_messages=[],  # txn made and saved
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
         self.assertEqual(0, txn.send.call_count)  # txn not sent though
         self.assertEqual(0, txn.complete.call_count)  # or completed
@@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             to_device_messages=[],
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
         self.assertEqual(1, self.recoverer_fn.call_count)  # recoverer made
         self.assertEqual(1, self.recoverer.recover.call_count)  # and invoked
@@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         service = Mock(id=4)
         event = Mock()
         self.scheduler.enqueue_for_appservice(service, events=[event])
-        self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
+        self.txn_ctrl.send.assert_called_once_with(
+            service, [event], [], [], None, None, DeviceListUpdates()
+        )
 
     def test_send_single_event_with_queue(self):
         d = defer.Deferred()
@@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # (call enqueue_for_appservice multiple times deliberately)
         self.scheduler.enqueue_for_appservice(service, events=[event2])
         self.scheduler.enqueue_for_appservice(service, events=[event3])
-        self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            service, [event], [], [], None, None, DeviceListUpdates()
+        )
         self.assertEqual(1, self.txn_ctrl.send.call_count)
         # Resolve the send event: expect the queued events to be sent
         d.callback(service)
         self.txn_ctrl.send.assert_called_with(
-            service, [event2, event3], [], [], None, None
+            service, [event2, event3], [], [], None, None, DeviceListUpdates()
         )
         self.assertEqual(2, self.txn_ctrl.send.call_count)
 
@@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # send events for different ASes and make sure they are sent
         self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
         self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
-        self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            srv1, [srv_1_event], [], [], None, None, DeviceListUpdates()
+        )
         self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
         self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
-        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            srv2, [srv_2_event], [], [], None, None, DeviceListUpdates()
+        )
 
         # make sure callbacks for a service only send queued events for THAT
         # service
         srv_2_defer.callback(srv2)
-        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates()
+        )
         self.assertEqual(3, self.txn_ctrl.send.call_count)
 
     def test_send_large_txns(self):
@@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
 
         # Expect the first event to be sent immediately.
         self.txn_ctrl.send.assert_called_with(
-            service, [event_list[0]], [], [], None, None
+            service, [event_list[0]], [], [], None, None, DeviceListUpdates()
         )
         srv_1_defer.callback(service)
         # Then send the next 100 events
         self.txn_ctrl.send.assert_called_with(
-            service, event_list[1:101], [], [], None, None
+            service, event_list[1:101], [], [], None, None, DeviceListUpdates()
         )
         srv_2_defer.callback(service)
         # Then the final 99 events
         self.txn_ctrl.send.assert_called_with(
-            service, event_list[101:], [], [], None, None
+            service, event_list[101:], [], [], None, None, DeviceListUpdates()
         )
         self.assertEqual(3, self.txn_ctrl.send.call_count)
 
@@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         event_list = [Mock(name="event")]
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
         self.txn_ctrl.send.assert_called_once_with(
-            service, [], event_list, [], None, None
+            service, [], event_list, [], None, None, DeviceListUpdates()
         )
 
     def test_send_multiple_ephemeral_no_queue(self):
@@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
         self.txn_ctrl.send.assert_called_once_with(
-            service, [], event_list, [], None, None
+            service, [], event_list, [], None, None, DeviceListUpdates()
         )
 
     def test_send_single_ephemeral_with_queue(self):
@@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # Send more events: expect send() to NOT be called multiple times.
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
-        self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            service, [], event_list_1, [], None, None, DeviceListUpdates()
+        )
         self.assertEqual(1, self.txn_ctrl.send.call_count)
         # Resolve txn_ctrl.send
         d.callback(service)
         # Expect the queued events to be sent
         self.txn_ctrl.send.assert_called_with(
-            service, [], event_list_2 + event_list_3, [], None, None
+            service,
+            [],
+            event_list_2 + event_list_3,
+            [],
+            None,
+            None,
+            DeviceListUpdates(),
         )
         self.assertEqual(2, self.txn_ctrl.send.call_count)
 
@@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         event_list = first_chunk + second_chunk
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
         self.txn_ctrl.send.assert_called_once_with(
-            service, [], first_chunk, [], None, None
+            service, [], first_chunk, [], None, None, DeviceListUpdates()
         )
         d.callback(service)
-        self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
+        self.txn_ctrl.send.assert_called_with(
+            service, [], second_chunk, [], None, None, DeviceListUpdates()
+        )
         self.assertEqual(2, self.txn_ctrl.send.call_count)
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 694020fbef..06e0545a4f 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -28,8 +28,8 @@ from tests import unittest
 SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
 
 KEY_ALG = "ed25519"
-KEY_VER = 1
-KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER)
+KEY_VER = "1"
+KEY_NAME = "%s:%s" % (KEY_ALG, KEY_VER)
 
 HOSTNAME = "domain"
 
@@ -39,7 +39,7 @@ class EventSigningTestCase(unittest.TestCase):
         # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
         # monkeypatched to include new `alg` and `version` attributes. This is captured
         # by the `signedjson.types.SigningKey` protocol.
-        self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey(
+        self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey(  # type: ignore[assignment]
             SIGNING_KEY_SEED
         )
         self.signing_key.alg = KEY_ALG
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
 from typing import Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized_class
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         )
 
 
+@parameterized_class(
+    [
+        {"enable_room_poke_code_path": False},
+        {"enable_room_poke_code_path": True},
+    ]
+)
 class FederationSenderDevicesTestCases(HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     def default_config(self):
         c = super().default_config()
         c["send_federation"] = True
+        c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_users_who_share_room_with_user so that it claims that
-        # `@user2:host2` is in the room
-        def get_users_who_share_room_with_user(user_id):
+        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        # server thinks the user shares a room with `@user2:host2`
+        def get_rooms_for_user(user_id):
+            return defer.succeed({"!room:host1"})
+
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+        def get_users_in_room(room_id):
             return defer.succeed({"@user2:host2"})
 
-        hs.get_datastores().main.get_users_who_share_room_with_user = (
-            get_users_who_share_room_with_user
-        )
+        hs.get_datastores().main.get_users_in_room = get_users_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 648a01618e..d21c11b716 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -23,7 +23,7 @@ from synapse.server import HomeServer
 from synapse.types import RoomAlias
 
 from tests.test_utils import event_injection
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase
 
 
 class KnockingStrippedStateEventHelperMixin(TestCase):
@@ -221,7 +221,6 @@ class FederationKnockingTestCase(
 
         return super().prepare(reactor, clock, homeserver)
 
-    @override_config({"experimental_features": {"msc2403_enabled": True}})
     def test_room_state_returned_when_knocking(self):
         """
         Tests that specific, stripped state events from a room are returned after
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index cead9f90df..8c72cf6b30 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -15,6 +15,8 @@
 from typing import Dict, Iterable, List, Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized
+
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
             to_device_messages,
             _otks,
             _fbks,
+            _device_list_summary,
         ) = self.send_mock.call_args[0]
 
         # Assert that this was the same to-device message that local_user sent
@@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         service_id_to_message_count: Dict[str, int] = {}
 
         for call in self.send_mock.call_args_list:
-            service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
+            (
+                service,
+                _events,
+                _ephemeral,
+                to_device_messages,
+                _otks,
+                _fbks,
+                _device_list_summary,
+            ) = call[0]
 
             # Check that this was made to an interested service
             self.assertIn(service, interested_appservices)
@@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         return appservice
 
 
+class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
+    """
+    Tests that the ApplicationServicesHandler sends device list updates to application
+    services correctly.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        # Allow us to modify cached feature flags mid-test
+        self.as_handler = hs.get_application_service_handler()
+
+        # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
+        # will be sent over the wire
+        self.put_json = simple_async_mock()
+        hs.get_application_service_api().put_json = self.put_json  # type: ignore[assignment]
+
+        # Mock out application services, and allow defining our own in tests
+        self._services: List[ApplicationService] = []
+        self.hs.get_datastores().main.get_app_services = Mock(
+            return_value=self._services
+        )
+
+    # Test across a variety of configuration values
+    @parameterized.expand(
+        [
+            (True, True, True),
+            (True, False, False),
+            (False, True, False),
+            (False, False, False),
+        ]
+    )
+    def test_application_service_receives_device_list_updates(
+        self,
+        experimental_feature_enabled: bool,
+        as_supports_txn_extensions: bool,
+        as_should_receive_device_list_updates: bool,
+    ):
+        """
+        Tests that an application service receives notice of changed device
+        lists for a user, when a user changes their device lists.
+
+        Arguments above are populated by parameterized.
+
+        Args:
+            as_should_receive_device_list_updates: Whether we expect the AS to receive the
+                device list changes.
+            experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
+                feature is enabled. This feature must be enabled for device lists to ASs to work.
+            as_supports_txn_extensions: Whether the application service has explicitly registered
+                to receive information defined by MSC3202 - which includes device list changes.
+        """
+        # Change whether the experimental feature is enabled or disabled before making
+        # device list changes
+        self.as_handler._msc3202_transaction_extensions_enabled = (
+            experimental_feature_enabled
+        )
+
+        # Create an appservice that is interested in "local_user"
+        appservice = ApplicationService(
+            token=random_string(10),
+            hostname="example.com",
+            id=random_string(10),
+            sender="@as:example.com",
+            rate_limited=False,
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {
+                        "regex": "@local_user:.+",
+                        "exclusive": False,
+                    }
+                ],
+            },
+            supports_ephemeral=True,
+            msc3202_transaction_extensions=as_supports_txn_extensions,
+            # Must be set for Synapse to try pushing data to the AS
+            hs_token="abcde",
+            url="some_url",
+        )
+
+        # Register the application service
+        self._services.append(appservice)
+
+        # Register a user on the homeserver
+        self.local_user = self.register_user("local_user", "password")
+        self.local_user_token = self.login("local_user", "password")
+
+        if as_should_receive_device_list_updates:
+            # Ensure that the resulting JSON uses the unstable prefix and contains the
+            # expected users
+            self.put_json.assert_called_once()
+            json_body = self.put_json.call_args[1]["json_body"]
+
+            # Our application service should have received a device list update with
+            # "local_user" in the "changed" list
+            device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
+            self.assertEqual([], device_list_dict["left"])
+            self.assertEqual([self.local_user], device_list_dict["changed"])
+
+        else:
+            # No device list changes should have been sent out
+            self.put_json.assert_not_called()
+
+
 class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
     # Argument indices for pulling out arguments from a `send_mock`.
     ARG_OTK_COUNTS = 4
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 3a10791226..7586e472b5 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -44,21 +44,20 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         Deactivates the account `self.user` using `self.token` and asserts
         that it returns a 200 success code.
         """
-        req = self.get_success(
-            self.make_request(
-                "POST",
-                "account/deactivate",
-                {
-                    "auth": {
-                        "type": "m.login.password",
-                        "user": self.user,
-                        "password": "pass",
-                    },
-                    "erase": True,
+        req = self.make_request(
+            "POST",
+            "account/deactivate",
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "user": self.user,
+                    "password": "pass",
                 },
-                access_token=self.token,
-            )
+                "erase": True,
+            },
+            access_token=self.token,
         )
+
         self.assertEqual(req.code, HTTPStatus.OK, req)
 
     def test_global_account_data_deleted_upon_deactivation(self) -> None:
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ac21a28c43..8c74ed1fcf 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res = e.value.code
         self.assertEqual(res, 400)
 
-        res = self.get_success(self.handler.query_local_devices({local_user: None}))
-        self.assertDictEqual(res, {local_user: {}})
+        query_res = self.get_success(
+            self.handler.query_local_devices({local_user: None})
+        )
+        self.assertDictEqual(query_res, {local_user: {}})
 
     def test_upload_signatures(self) -> None:
         """should check signatures that are uploaded"""
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 4d65639a1e..060ba5f517 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -20,17 +20,17 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
+from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.federation_base import event_from_pdu_json
 from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import create_requester
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.test_utils import event_injection
 
 logger = logging.getLogger(__name__)
 
@@ -39,7 +39,7 @@ def generate_fake_event_id() -> str:
     return "$fake_" + random_string(43)
 
 
-class FederationTestCase(unittest.HomeserverTestCase):
+class FederationTestCase(unittest.FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -219,41 +219,77 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # create the room
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
-        requester = create_requester(user_id)
 
         room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+        room_version = self.get_success(self.store.get_room_version(room_id))
+
+        # we need a user on the remote server to be a member, so that we can send
+        # extremity-causing events.
+        self.get_success(
+            event_injection.inject_member_event(
+                self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
+            )
+        )
 
-        ev1 = self.helper.send(room_id, "first message", tok=tok)
+        send_result = self.helper.send(room_id, "first message", tok=tok)
+        ev1 = self.get_success(
+            self.store.get_event(send_result["event_id"], allow_none=False)
+        )
+        current_state = self.get_success(
+            self.store.get_events_as_list(
+                (self.get_success(self.store.get_current_state_ids(room_id))).values()
+            )
+        )
 
         # Create "many" backward extremities. The magic number we're trying to
         # create more than is 5 which corresponds to the number of backward
         # extremities we slice off in `_maybe_backfill_inner`
+        federation_event_handler = self.hs.get_federation_event_handler()
         for _ in range(0, 8):
-            event_handler = self.hs.get_event_creation_handler()
-            event, context = self.get_success(
-                event_handler.create_event(
-                    requester,
+            event = make_event_from_dict(
+                self.add_hashes_and_signatures(
                     {
+                        "origin_server_ts": 1,
                         "type": "m.room.message",
                         "content": {
                             "msgtype": "m.text",
                             "body": "message connected to fake event",
                         },
                         "room_id": room_id,
-                        "sender": user_id,
+                        "sender": f"@user:{self.OTHER_SERVER_NAME}",
+                        "prev_events": [
+                            ev1.event_id,
+                            # We're creating an backward extremity each time thanks
+                            # to this fake event
+                            generate_fake_event_id(),
+                        ],
+                        # lazy: *everything* is an auth event
+                        "auth_events": [ev.event_id for ev in current_state],
+                        "depth": ev1.depth + 1,
                     },
-                    prev_event_ids=[
-                        ev1["event_id"],
-                        # We're creating an backward extremity each time thanks
-                        # to this fake event
-                        generate_fake_event_id(),
-                    ],
-                )
+                    room_version,
+                ),
+                room_version,
             )
+
+            # we poke this directly into _process_received_pdu, to avoid the
+            # federation handler wanting to backfill the fake event.
             self.get_success(
-                event_handler.handle_new_client_event(requester, event, context)
+                federation_event_handler._process_received_pdu(
+                    self.OTHER_SERVER_NAME, event, state=current_state
+                )
             )
 
+        # we should now have 8 backwards extremities.
+        backwards_extremities = self.get_success(
+            self.store.db_pool.simple_select_list(
+                "event_backward_extremities",
+                keyvalues={"room_id": room_id},
+                retcols=["event_id"],
+            )
+        )
+        self.assertEqual(len(backwards_extremities), 8)
+
         current_depth = 1
         limit = 100
         with LoggingContext("receive_pdu"):
@@ -339,7 +375,8 @@ class FederationTestCase(unittest.HomeserverTestCase):
         member_event.signatures = member_event_dict["signatures"]
 
         # Add the new member_event to the StateMap
-        prev_state_map[
+        updated_state_map = dict(prev_state_map)
+        updated_state_map[
             (member_event.type, member_event.state_key)
         ] = member_event.event_id
         auth_events.append(member_event)
@@ -363,7 +400,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 prev_event_ids=message_event_dict["prev_events"],
                 auth_event_ids=self._event_auth_handler.compute_auth_events(
                     builder,
-                    prev_state_map,
+                    updated_state_map,
                     for_verification=False,
                 ),
                 depth=message_event_dict["depth"],
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
new file mode 100644
index 0000000000..489ba57736
--- /dev/null
+++ b/tests/handlers/test_federation_event.py
@@ -0,0 +1,225 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.federation.transport.client import StateRequestResponse
+from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client import login, room
+
+from tests import unittest
+from tests.test_utils import event_injection, make_awaitable
+
+
+class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        # mock out the federation transport client
+        self.mock_federation_transport_client = mock.Mock(
+            spec=["get_room_state_ids", "get_room_state", "get_event"]
+        )
+        return super().setup_test_homeserver(
+            federation_transport_client=self.mock_federation_transport_client
+        )
+
+    def test_process_pulled_event_with_missing_state(self) -> None:
+        """Ensure that we correctly handle pulled events with lots of missing state
+
+        In this test, we pretend we are processing a "pulled" event (eg, via backfill
+        or get_missing_events). The pulled event has a prev_event we haven't previously
+        seen, so the server requests the state at that prev_event. There is a lot
+        of state we don't have, so we expect the server to make a /state request.
+
+        We check that the pulled event is correctly persisted, and that the state is
+        as we expect.
+        """
+        return self._test_process_pulled_event_with_missing_state(False)
+
+    def test_process_pulled_event_with_missing_state_where_prev_is_outlier(
+        self,
+    ) -> None:
+        """Ensure that we correctly handle pulled events with lots of missing state
+
+        A slight modification to test_process_pulled_event_with_missing_state. Again
+        we have a "pulled" event which refers to a prev_event with lots of state,
+        but in this case we already have the prev_event (as an outlier, obviously -
+        if it were a regular event, we wouldn't need to request the state).
+        """
+        return self._test_process_pulled_event_with_missing_state(True)
+
+    def _test_process_pulled_event_with_missing_state(
+        self, prev_exists_as_outlier: bool
+    ) -> None:
+        OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+        main_store = self.hs.get_datastores().main
+        state_storage = self.hs.get_storage().state
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+        room_version = self.get_success(main_store.get_room_version(room_id))
+
+        # allow the remote user to send state events
+        self.helper.send_state(
+            room_id,
+            "m.room.power_levels",
+            {"events_default": 0, "state_default": 0},
+            tok=tok,
+        )
+
+        # add the remote user to the room
+        member_event = self.get_success(
+            event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+        )
+
+        initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+
+        auth_event_ids = [
+            initial_state_map[("m.room.create", "")],
+            initial_state_map[("m.room.power_levels", "")],
+            initial_state_map[("m.room.join_rules", "")],
+            member_event.event_id,
+        ]
+
+        # mock up a load of state events which we are missing
+        state_events = [
+            make_event_from_dict(
+                self.add_hashes_and_signatures(
+                    {
+                        "type": "test_state_type",
+                        "state_key": f"state_{i}",
+                        "room_id": room_id,
+                        "sender": OTHER_USER,
+                        "prev_events": [member_event.event_id],
+                        "auth_events": auth_event_ids,
+                        "origin_server_ts": 1,
+                        "depth": 10,
+                        "content": {"body": f"state_{i}"},
+                    }
+                ),
+                room_version,
+            )
+            for i in range(1, 10)
+        ]
+
+        # this is the state that we are going to claim is active at the prev_event.
+        state_at_prev_event = state_events + self.get_success(
+            main_store.get_events_as_list(initial_state_map.values())
+        )
+
+        # mock up a prev event.
+        # Depending on the test, we either persist this upfront (as an outlier),
+        # or let the server request it.
+        prev_event = make_event_from_dict(
+            self.add_hashes_and_signatures(
+                {
+                    "type": "test_regular_type",
+                    "room_id": room_id,
+                    "sender": OTHER_USER,
+                    "prev_events": [],
+                    "auth_events": auth_event_ids,
+                    "origin_server_ts": 1,
+                    "depth": 11,
+                    "content": {"body": "missing_prev"},
+                }
+            ),
+            room_version,
+        )
+        if prev_exists_as_outlier:
+            prev_event.internal_metadata.outlier = True
+            persistence = self.hs.get_storage().persistence
+            self.get_success(
+                persistence.persist_event(prev_event, EventContext.for_outlier())
+            )
+        else:
+
+            async def get_event(destination: str, event_id: str, timeout=None):
+                self.assertEqual(destination, self.OTHER_SERVER_NAME)
+                self.assertEqual(event_id, prev_event.event_id)
+                return {"pdus": [prev_event.get_pdu_json()]}
+
+            self.mock_federation_transport_client.get_event.side_effect = get_event
+
+        # mock up a regular event to pass into _process_pulled_event
+        pulled_event = make_event_from_dict(
+            self.add_hashes_and_signatures(
+                {
+                    "type": "test_regular_type",
+                    "room_id": room_id,
+                    "sender": OTHER_USER,
+                    "prev_events": [prev_event.event_id],
+                    "auth_events": auth_event_ids,
+                    "origin_server_ts": 1,
+                    "depth": 12,
+                    "content": {"body": "pulled"},
+                }
+            ),
+            room_version,
+        )
+
+        # we expect an outbound request to /state_ids, so stub that out
+        self.mock_federation_transport_client.get_room_state_ids.return_value = (
+            make_awaitable(
+                {
+                    "pdu_ids": [e.event_id for e in state_at_prev_event],
+                    "auth_chain_ids": [],
+                }
+            )
+        )
+
+        # we also expect an outbound request to /state
+        self.mock_federation_transport_client.get_room_state.return_value = (
+            make_awaitable(
+                StateRequestResponse(auth_events=[], state=state_at_prev_event)
+            )
+        )
+
+        # we have to bump the clock a bit, to keep the retry logic in
+        # FederationClient.get_pdu happy
+        self.reactor.advance(60000)
+
+        # Finally, the call under test: send the pulled event into _process_pulled_event
+        with LoggingContext("test"):
+            self.get_success(
+                self.hs.get_federation_event_handler()._process_pulled_event(
+                    self.OTHER_SERVER_NAME, pulled_event, backfilled=False
+                )
+            )
+
+        # check that the event is correctly persisted
+        persisted = self.get_success(main_store.get_event(pulled_event.event_id))
+        self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+        self.assertFalse(
+            persisted.internal_metadata.is_outlier(), "pulled event was an outlier"
+        )
+
+        # check that the state at that event is as expected
+        state = self.get_success(
+            state_storage.get_state_ids_for_event(pulled_event.event_id)
+        )
+        expected_state = {
+            (e.type, e.state_key): e.event_id for e in state_at_prev_event
+        }
+        self.assertEqual(state, expected_state)
+
+        if prev_exists_as_outlier:
+            self.mock_federation_transport_client.get_event.assert_not_called()
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 014815db6e..9684120c70 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         req = Mock(spec=["cookies"])
         req.cookies = []
 
-        url = self.get_success(
-            self.provider.handle_redirect_request(req, b"http://client/redirect")
+        url = urlparse(
+            self.get_success(
+                self.provider.handle_redirect_request(req, b"http://client/redirect")
+            )
         )
-        url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 1ec105c373..f88c725a42 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -59,7 +59,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        self.get_success(self.register_user(self.frank.localpart, "frankpassword"))
+        self.register_user(self.frank.localpart, "frankpassword")
 
         self.handler = hs.get_profile_handler()
 
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 3aedc0767b..865b8b7e47 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -158,9 +158,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             )
 
         # Blow away caches (supported room versions can only change due to a restart).
-        self.get_success(
-            self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
-        )
+        self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
         self.store._get_event_cache.clear()
 
         # The rooms should be excluded from the sync response.
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 92012cd6f7..c6e501c7be 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
     def test_handle_local_profile_change_with_deactivated_user(self) -> None:
@@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # profile is in directory
         profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
         # deactivate user
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 10dd94b549..9fd5d59c55 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -87,11 +87,23 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(displayname, "Bobberino")
 
     def test_can_register_admin_user(self):
-        user_id = self.get_success(
-            self.register_user(
-                "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
-            )
+        user_id = self.register_user(
+            "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
         )
+
+        found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        self.assertEqual(found_user.user_id.to_string(), user_id)
+        self.assertIdentical(found_user.is_admin, True)
+
+    def test_can_set_admin(self):
+        user_id = self.register_user(
+            "alice_wants_admin",
+            "1234",
+            displayname="Alice Powerhungry",
+            admin=False,
+        )
+
+        self.get_success(self.module_api.set_user_admin(user_id, True))
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
@@ -278,7 +290,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Create a user and room to play with
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
-        room_id = self.helper.create_room_as(user_id, tok=tok)
+        room_id = self.helper.create_room_as(user_id, tok=tok, is_public=False)
 
         # The room should not currently be in the public rooms directory
         is_in_public_rooms = self.get_success(
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 17dc42fd37..297a9e77f8 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -268,7 +268,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         event_source = RoomEventSource(self.hs)
         event_source.store = self.slaved_store
-        current_token = self.get_success(event_source.get_current_key())
+        current_token = event_source.get_current_key()
 
         # gradually stream out the replication
         while repl_transport.buffer:
@@ -277,7 +277,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             self.pump(0)
 
             prev_token = current_token
-            current_token = self.get_success(event_source.get_current_key())
+            current_token = event_source.get_current_key()
 
             # attempt to replicate the behaviour of the sync handler.
             #
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 0d47dd0aff..e909e444ac 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         """
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
         # quarantining
@@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["quarantined_by"])
 
         # remove from quarantine
@@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
     def test_quarantine_protected_media(self) -> None:
@@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["safe_from_quarantine"])
 
         # quarantining
@@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
 
@@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         """
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["safe_from_quarantine"])
 
         # protect
@@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["safe_from_quarantine"])
 
         # unprotect
@@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["safe_from_quarantine"])
 
 
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2c855bff99..a53463c9ba 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -214,9 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(messages[0]["sender"], "@notices:test")
 
         # invalidate cache of server notices room_ids
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
@@ -291,9 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         # invalidate cache of server notices room_ids
         # if server tries to send to a cached room_id the user gets the message
         # in old room
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
@@ -380,9 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
 
         # invalidate cache of server notices room_ids
         # if server tries to send to a cached room_id it gives an error
-        self.get_success(
-            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
-        )
+        self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
 
         # send second message
         channel = self.make_request(
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index bef911d5df..0cdf1dec40 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
 
-        pushers = self.get_success(
-            self.store.get_pushers_by({"user_name": "@bob:test"})
+        pushers = list(
+            self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual("@bob:test", pushers[0].user_name)
 
@@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
 
-        pushers = self.get_success(
-            self.store.get_pushers_by({"user_name": "@bob:test"})
+        pushers = list(
+            self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def test_set_password(self) -> None:
@@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # is in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == "User")
 
         # Deactivate user
@@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         user_tuple = self.get_success(
             self.store.get_user_by_access_token(other_user_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         # The user starts off as not shadow-banned.
         other_user_token = self.login("user", "pass")
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertFalse(result.shadow_banned)
 
         channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
@@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure the user is shadow-banned (and the cache was cleared).
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertTrue(result.shadow_banned)
 
         # Un-shadow-ban the user.
@@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure the user is no longer shadow-banned (and the cache was cleared).
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertFalse(result.shadow_banned)
 
 
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 27946febff..e00b5c171c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -89,6 +89,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
+    def attempt_wrong_password_login(self, username: str, password: str) -> None:
+        """Attempts to login as the user with the given password, asserting
+        that the attempt *fails*.
+        """
+        body = {"type": "m.login.password", "user": username, "password": password}
+
+        channel = self.make_request(
+            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+
     def test_basic_password_reset(self) -> None:
         """Test basic password reset flow"""
         old_password = "monkey"
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
new file mode 100644
index 0000000000..d5b0640e7a
--- /dev/null
+++ b/tests/rest/client/test_account_data.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+from synapse.rest import admin
+from synapse.rest.client import account_data, login, room
+
+from tests import unittest
+from tests.test_utils import make_awaitable
+
+
+class AccountDataTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        account_data.register_servlets,
+    ]
+
+    def test_on_account_data_updated_callback(self) -> None:
+        """Tests that the on_account_data_updated module callback is called correctly when
+        a user's account data changes.
+        """
+        mocked_callback = Mock(return_value=make_awaitable(None))
+        self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
+            mocked_callback
+        )
+
+        user_id = self.register_user("user", "password")
+        tok = self.login("user", "password")
+        account_data_type = "org.matrix.foo"
+        account_data_content = {"bar": "baz"}
+
+        # Change the user's global account data.
+        channel = self.make_request(
+            "PUT",
+            f"/user/{user_id}/account_data/{account_data_type}",
+            account_data_content,
+            access_token=tok,
+        )
+
+        # Test that the callback is called with the user ID, the new account data, and
+        # None as the room ID.
+        self.assertEqual(channel.code, 200, channel.result)
+        mocked_callback.assert_called_once_with(
+            user_id, None, account_data_type, account_data_content
+        )
+
+        # Change the user's room-specific account data.
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+        channel = self.make_request(
+            "PUT",
+            f"/user/{user_id}/rooms/{room_id}/account_data/{account_data_type}",
+            account_data_content,
+            access_token=tok,
+        )
+
+        # Test that the callback is called with the user ID, the room ID and the new
+        # account data.
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(mocked_callback.call_count, 2)
+        mocked_callback.assert_called_with(
+            user_id, room_id, account_data_type, account_data_content
+        )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fe97a0b3dd..419eef166a 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import urllib.parse
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
@@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         return channel.json_body["unsigned"].get("m.relations", {})
 
-    def _get_aggregations(self) -> List[JsonDict]:
-        """Request /aggregations on the parent ID and includes the returned chunk."""
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        return channel.json_body["chunk"]
-
     def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
         """
         Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
@@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             expected_response_code=400,
         )
 
-    def test_aggregation(self) -> None:
-        """Test that annotations get correctly aggregated."""
-
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        self.assertEqual(
-            channel.json_body,
-            {
-                "chunk": [
-                    {"type": "m.reaction", "key": "a", "count": 2},
-                    {"type": "m.reaction", "key": "b", "count": 1},
-                ]
-            },
-        )
-
-    def test_aggregation_must_be_annotation(self) -> None:
-        """Test that aggregations must be annotations."""
-
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
-            f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(400, channel.code, channel.json_body)
-
     def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
@@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-        # And when fetching aggregations.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
-
         # And for bundled aggregations.
         channel = self.make_request(
             "GET",
@@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
-        # But unknown relations can be directly queried.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
-
     def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
                 annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
             )
 
-    def test_aggregation_pagination_groups(self) -> None:
-        """Test that we can paginate annotation groups correctly."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
-        for key in itertools.chain.from_iterable(
-            itertools.repeat(key, num) for key, num in sent_groups.items()
-        ):
-            self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key=key,
-                access_token=access_tokens[idx],
-            )
-
-            idx += 1
-            idx %= len(access_tokens)
-
-        prev_token: Optional[str] = None
-        found_groups: Dict[str, int] = {}
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            for groups in channel.json_body["chunk"]:
-                # We only expect reactions
-                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
-                # We should only see each key once
-                self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
-                found_groups[groups["key"]] = groups["count"]
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        self.assertEqual(sent_groups, found_groups)
-
-    def test_aggregation_pagination_within_group(self) -> None:
-        """Test that we can paginate within an annotation group."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        expected_event_ids = []
-        for _ in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key="👍",
-                access_token=access_tokens[idx],
-            )
-            expected_event_ids.append(channel.json_body["event_id"])
-
-            idx += 1
-
-        # Also send a different type of reaction so that we test we don't see it
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        encoded_key = urllib.parse.quote_plus("👍".encode())
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
 
 class BundledAggregationsTestCase(BaseRelationsTestCase):
     """
@@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
         )
 
-        # Both relations appear in the aggregation.
-        chunk = self._get_aggregations()
-        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
-
         # Redact one of the reactions.
         self._redact(to_redact_event_id)
 
@@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
         )
 
-        # The unredacted aggregation should still exist.
-        chunk = self._get_aggregations()
-        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
-
     def test_redact_relation_thread(self) -> None:
         """
         Test that thread replies are properly handled after the thread reply redacted.
@@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self.assertEqual(len(event_ids), 1)
         self.assertIn(RelationTypes.ANNOTATION, relations)
 
-        # The aggregation should exist.
-        chunk = self._get_aggregations()
-        self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
-
         # Redact the original event.
         self._redact(self.parent_id)
 
@@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
         )
 
-        # There's nothing to aggregate.
-        chunk = self._get_aggregations()
-        self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
-
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_redact_parent_thread(self) -> None:
         """
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 3a9617d6da..6ff79b9e2e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
         super().prepare(reactor, clock, hs)
         # profile changes expect that the user is actually registered
         user = UserID.from_string(self.user_id)
-        self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+        self.register_user(user.localpart, "supersecretpassword")
 
     @unittest.override_config(
         {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 4351013952..773c16a54c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -341,7 +341,6 @@ class SyncKnockTestCase(
             hs, self.room_id, self.user_id
         )
 
-    @override_config({"experimental_features": {"msc2403_enabled": True}})
     def test_knock_room_state(self) -> None:
         """Tests that /sync returns state from a room after knocking on it."""
         # Knock on a room
@@ -497,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         receipts.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = {"msc2654_enabled": True}
+        return config
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
@@ -772,3 +776,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         self.assertIn(
             self.user_id, device_list_changes, incremental_sync_channel.json_body
         )
+
+
+class ExcludeRoomTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.user_id = self.register_user("user", "password")
+        self.tok = self.login("user", "password")
+
+        self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # We need to manually append the room ID, because we can't know the ID before
+        # creating the room, and we can't set the config after starting the homeserver.
+        self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id)
+
+    def test_join_leave(self) -> None:
+        """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of
+        sync responses.
+        """
+        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
+
+        self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok)
+        self.helper.leave(self.included_room_id, self.user_id, tok=self.tok)
+
+        channel = self.make_request(
+            "GET",
+            "/sync?since=" + channel.json_body["next_batch"],
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"])
+
+    def test_invite(self) -> None:
+        """Tests that rooms are correctly excluded from the 'invite' section of sync
+        responses.
+        """
+        invitee = self.register_user("invitee", "password")
+        invitee_tok = self.login("invitee", "password")
+
+        self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok)
+        self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok)
+
+        channel = self.make_request("GET", "/sync", access_token=invitee_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e7de67e3a3..5eb0f243f7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Check that the mock was called with the right room ID
         self.assertEqual(args[1], self.room_id)
+
+    def test_on_threepid_bind(self) -> None:
+        """Tests that the on_threepid_bind module callback is called correctly after
+        associating a 3PID to an account.
+        """
+        # Register a mocked callback.
+        threepid_bind_mock = Mock(return_value=make_awaitable(None))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Also register a normal user we can modify.
+        user_id = self.register_user("user", "password")
+
+        # Add a 3PID to the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [
+                    {
+                        "medium": "email",
+                        "address": "foo@example.com",
+                    },
+                ],
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the shutdown was blocked
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        threepid_bind_mock.assert_called_once()
+        args = threepid_bind_mock.call_args[0]
+
+        # Check that the mock was called with the right parameters
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 28663826fc..a0788b1bb0 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -88,7 +88,7 @@ class RestHelper:
     def create_room_as(
         self,
         room_creator: Optional[str] = None,
-        is_public: Optional[bool] = None,
+        is_public: Optional[bool] = True,
         room_version: Optional[str] = None,
         tok: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
@@ -101,9 +101,12 @@ class RestHelper:
         Args:
             room_creator: The user ID to create the room with.
             is_public: If True, the `visibility` parameter will be set to
-                "public". If False, it will be set to "private". If left
-                unspecified, the server will set it to an appropriate default
-                (which should be "private" as per the CS spec).
+                "public". If False, it will be set to "private".
+                If None, doesn't specify the `visibility` parameter in which
+                case the server is supposed to make the room private according to
+                the CS API.
+                Defaults to public, since that is commonly needed in tests
+                for convenience where room privacy is not a problem.
             room_version: The room version to create the room as. Defaults to Synapse's
                 default room version.
             tok: The access token to use in the request.
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 978c252f84..ac0ac06b7e 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -76,7 +76,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
                 "verify_keys": {
                     key_id: {
                         "key": signedjson.key.encode_verify_key_base64(
-                            signing_key.verify_key
+                            signedjson.key.get_verify_key(signing_key)
                         )
                     }
                 },
@@ -175,7 +175,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
                     % (
                         self.hs_signing_key.version,
                     ): signedjson.key.encode_verify_key_base64(
-                        self.hs_signing_key.verify_key
+                        signedjson.key.get_verify_key(self.hs_signing_key)
                     )
                 },
             }
@@ -229,7 +229,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
-            signedjson.key.encode_verify_key_base64(testkey.verify_key),
+            signedjson.key.encode_verify_key_base64(
+                signedjson.key.get_verify_key(testkey)
+            ),
         )
 
     def test_get_notary_key(self) -> None:
@@ -251,7 +253,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
-            signedjson.key.encode_verify_key_base64(testkey.verify_key),
+            signedjson.key.encode_verify_key_base64(
+                signedjson.key.get_verify_key(testkey)
+            ),
         )
 
     def test_get_notary_keyserver_key(self) -> None:
@@ -268,5 +272,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         assert isinstance(keyres, FetchKeyResult)
         self.assertEqual(
             signedjson.key.encode_verify_key_base64(keyres.verify_key),
-            signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key),
+            signedjson.key.encode_verify_key_base64(
+                signedjson.key.get_verify_key(self.hs_signing_key)
+            ),
         )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 5148c39874..3b24d0ace6 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -17,7 +17,7 @@ import json
 import os
 import re
 from typing import Any, Dict, Optional, Sequence, Tuple, Type
-from urllib.parse import urlencode
+from urllib.parse import quote, urlencode
 
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
@@ -69,7 +69,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "2001:800::/21",
         )
         config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
-        config["url_preview_url_blacklist"] = []
         config["url_preview_accept_language"] = [
             "en-UK",
             "en-US;q=0.9",
@@ -1123,3 +1122,43 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 os.path.exists(path),
                 f"{os.path.relpath(path, self.media_store_path)} was not deleted",
             )
+
+    @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
+    def test_blacklist_port(self) -> None:
+        """Tests that blacklisting URLs with a port makes previewing such URLs
+        fail with a 403 error and doesn't impact other previews.
+        """
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        bad_url = quote("http://matrix.org:8888/foo")
+        good_url = quote("http://matrix.org/foo")
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=" + bad_url,
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+        self.assertEqual(channel.code, 403, channel.result)
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=" + good_url,
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+            % (len(self.end_content),)
+            + self.end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
diff --git a/tests/server.py b/tests/server.py
index 6ce2a17bf4..16559d2588 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,7 +22,6 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
-    AnyStr,
     Callable,
     Dict,
     Iterable,
@@ -77,6 +76,7 @@ from tests.utils import (
     POSTGRES_BASE_DB,
     POSTGRES_HOST,
     POSTGRES_PASSWORD,
+    POSTGRES_PORT,
     POSTGRES_USER,
     SQLITE_PERSIST_DB,
     USE_POSTGRES_FOR_TESTS,
@@ -86,6 +86,9 @@ from tests.utils import (
 
 logger = logging.getLogger(__name__)
 
+# the type of thing that can be passed into `make_request` in the headers list
+CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+
 
 class TimedOutException(Exception):
     """
@@ -260,7 +263,7 @@ def make_request(
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
-    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+    custom_headers: Optional[Iterable[CustomHeaderType]] = None,
     client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
@@ -745,6 +748,7 @@ def setup_test_homeserver(
                 "host": POSTGRES_HOST,
                 "password": POSTGRES_PASSWORD,
                 "user": POSTGRES_USER,
+                "port": POSTGRES_PORT,
                 "cp_min": 1,
                 "cp_max": 5,
             },
@@ -784,6 +788,7 @@ def setup_test_homeserver(
             database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
+            port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
         db_conn.autocommit = True
@@ -831,6 +836,7 @@ def setup_test_homeserver(
                 database=POSTGRES_BASE_DB,
                 user=POSTGRES_USER,
                 host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
             )
             db_conn.autocommit = True
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3ac4646969..74c6224eb6 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """
         # First to acquire this lock, so it should complete
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         # Enter the context manager
         self.get_success(lock.__aenter__())
@@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
 
         # We can now acquire the lock again.
         lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock3)
+        assert lock3 is not None
         self.get_success(lock3.__aenter__())
         self.get_success(lock3.__aexit__(None, None, None))
 
@@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """Test that we don't time out locks while they're still active"""
 
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         self.get_success(lock.__aenter__())
 
@@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """Test that we time out locks if they're not updated for ages"""
 
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         self.get_success(lock.__aenter__())
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 4899cd5c36..366398e39d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,12 +14,18 @@
 # limitations under the License.
 
 import secrets
+from typing import Any, Dict, Generator, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class UpsertManyTests(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.storage = hs.get_datastores().main
 
         self.table_name = "table_" + secrets.token_hex(6)
@@ -40,11 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
 
-    def _dump_to_tuple(self, res):
+    def _dump_to_tuple(
+        self, res: List[Dict[str, Any]]
+    ) -> Generator[Tuple[int, str, str], None, None]:
         for i in res:
             yield (i["id"], i["username"], i["value"])
 
-    def test_upsert_many(self):
+    def test_upsert_many(self) -> None:
         """
         Upsert_many will perform the upsert operation across a batch of data.
         """
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ee599f4336..1bf93e79a7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.types import DeviceListUpdates
 from synapse.util import Clock
 
 from tests import unittest
@@ -168,15 +169,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
             (as_id, txn_id, json.dumps([e.event_id for e in events])),
         )
 
-    def _set_last_txn(self, as_id, txn_id):
-        return self.db_pool.runOperation(
-            self.engine.convert_param_style(
-                "INSERT INTO application_services_state(as_id, last_txn, state) "
-                "VALUES(?,?,?)"
-            ),
-            (as_id, txn_id, ApplicationServiceState.UP.value),
-        )
-
     def test_get_appservice_state_none(
         self,
     ) -> None:
@@ -267,65 +259,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
         txn = self.get_success(
             defer.ensureDeferred(
-                self.store.create_appservice_txn(service, events, [], [], {}, {})
+                self.store.create_appservice_txn(
+                    service, events, [], [], {}, {}, DeviceListUpdates()
+                )
             )
         )
         self.assertEqual(txn.id, 1)
         self.assertEqual(txn.events, events)
         self.assertEqual(txn.service, service)
 
-    def test_create_appservice_txn_older_last_txn(
-        self,
-    ) -> None:
-        service = Mock(id=self.as_list[0]["id"])
-        events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
-        self.get_success(self._set_last_txn(service.id, 9643))  # AS is falling behind
-        self.get_success(self._insert_txn(service.id, 9644, events))
-        self.get_success(self._insert_txn(service.id, 9645, events))
-        txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [], {}, {})
-        )
-        self.assertEqual(txn.id, 9646)
-        self.assertEqual(txn.events, events)
-        self.assertEqual(txn.service, service)
-
-    def test_create_appservice_txn_up_to_date_last_txn(
-        self,
-    ) -> None:
-        service = Mock(id=self.as_list[0]["id"])
-        events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
-        self.get_success(self._set_last_txn(service.id, 9643))
-        txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [], {}, {})
-        )
-        self.assertEqual(txn.id, 9644)
-        self.assertEqual(txn.events, events)
-        self.assertEqual(txn.service, service)
-
-    def test_create_appservice_txn_up_fuzzing(
-        self,
-    ) -> None:
-        service = Mock(id=self.as_list[0]["id"])
-        events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
-        self.get_success(self._set_last_txn(service.id, 9643))
-
-        # dump in rows with higher IDs to make sure the queries aren't wrong.
-        self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643))
-        self.get_success(self._set_last_txn(self.as_list[2]["id"], 9))
-        self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643))
-        self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events))
-        self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events))
-        self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events))
-        self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
-        self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
-
-        txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [], {}, {})
-        )
-        self.assertEqual(txn.id, 9644)
-        self.assertEqual(txn.events, events)
-        self.assertEqual(txn.service, service)
-
     def test_complete_appservice_txn_first_txn(
         self,
     ) -> None:
@@ -359,13 +301,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(0, len(res))
 
-    def test_complete_appservice_txn_existing_in_state_table(
+    def test_complete_appservice_txn_updates_last_txn_state(
         self,
     ) -> None:
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         txn_id = 5
-        self.get_success(self._set_last_txn(service.id, 4))
+        self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
         self.get_success(self._insert_txn(service.id, txn_id, events))
         self.get_success(
             self.store.complete_appservice_txn(txn_id=txn_id, service=service)
@@ -416,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(self._insert_txn(service.id, 12, other_events))
 
         txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+        assert txn is not None
         self.assertEqual(service, txn.service)
         self.assertEqual(10, txn.id)
         self.assertEqual(events, txn.events)
@@ -476,12 +419,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
         value = self.get_success(
             self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
         )
-        self.assertEqual(value, 0)
+        self.assertEqual(value, 1)
 
         value = self.get_success(
             self.store.get_type_stream_id_for_appservice(self.service, "presence")
         )
-        self.assertEqual(value, 0)
+        self.assertEqual(value, 1)
 
     def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
         self.get_failure(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index ce89c96912..b998ad42d9 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -68,6 +68,22 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
         self.wait_for_background_updates()
 
+    def add_extremity(self, room_id: str, event_id: str) -> None:
+        """
+        Add the given event as an extremity to the room.
+        """
+        self.get_success(
+            self.hs.get_datastores().main.db_pool.simple_insert(
+                table="event_forward_extremities",
+                values={"room_id": room_id, "event_id": event_id},
+                desc="test_add_extremity",
+            )
+        )
+
+        self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
+            (room_id,)
+        )
+
     def test_soft_failed_extremities_handled_correctly(self):
         """Test that extremities are correctly calculated in the presence of
         soft failed events.
@@ -250,7 +266,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
         self.requester = create_requester(self.user)
-        info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
+        info, _ = self.get_success(
+            self.room_creator.create_room(self.requester, {"visibility": "public"})
+        )
         self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()
         homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         # Add two device updates with sequential `stream_id`s
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
             "device_id5",
         ]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Add some more device updates to ensure it still resumes properly
         device_ids = ["device_id6", "device_id7"]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         self.get_success(
             self.store.add_device_change_to_streams(
-                "@user_id:test", device_ids, ["somehost"]
+                "@user_id:test", device_ids, ["somehost"], ["!some:room"]
             )
         )
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 395396340b..2d8d1f860f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx1 = self.get_success(id_gen.get_next())
-        ctx2 = self.get_success(id_gen.get_next())
-        ctx3 = self.get_success(id_gen.get_next())
-        ctx4 = self.get_success(id_gen.get_next())
+        ctx1 = id_gen.get_next()
+        ctx2 = id_gen.get_next()
+        ctx3 = id_gen.get_next()
+        ctx4 = id_gen.get_next()
 
         s1 = self.get_success(ctx1.__aenter__())
         s2 = self.get_success(ctx2.__aenter__())
@@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Persist two rows at once
-        ctx1 = self.get_success(id_gen.get_next())
-        ctx2 = self.get_success(id_gen.get_next())
+        ctx1 = id_gen.get_next()
+        ctx2 = id_gen.get_next()
 
         s1 = self.get_success(ctx1.__aenter__())
         s2 = self.get_success(ctx2.__aenter__())
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 79648d45db..60c8d37594 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, List
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import UserTypes
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -24,7 +28,7 @@ from tests.unittest import default_config, override_config
 FORTY_DAYS = 40 * 24 * 60 * 60
 
 
-def gen_3pids(count):
+def gen_3pids(count: int) -> List[Dict[str, Any]]:
     """Generate `count` threepids as a list."""
     return [
         {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
@@ -32,7 +36,7 @@ def gen_3pids(count):
 
 
 class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
         config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
@@ -44,10 +48,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
         # Advance the clock a bit
-        reactor.advance(FORTY_DAYS)
+        self.reactor.advance(FORTY_DAYS)
 
     @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
     def test_initialise_reserved_users(self):
@@ -245,7 +249,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -253,9 +257,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
 
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
@@ -266,9 +270,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))  # type: ignore[assignment]
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
@@ -324,16 +328,15 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         count = self.get_success(self.store.get_monthly_active_count())
         self.assertEqual(count, 0)
 
-        d = self.store.register_user(
-            user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
+        self.get_success(
+            self.store.register_user(
+                user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
+            )
         )
-        self.get_success(d)
 
-        d = self.store.upsert_monthly_active_user(support_user_id)
-        self.get_success(d)
+        self.get_success(self.store.upsert_monthly_active_user(support_user_id))
 
-        d = self.store.get_monthly_active_count()
-        count = self.get_success(d)
+        count = self.get_success(self.store.get_monthly_active_count())
         self.assertEqual(count, 0)
 
     # Note that the max_mau_value setting should not matter.
@@ -352,7 +355,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
@@ -388,16 +391,16 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             self.store.register_user(user_id=native_user1, password_hash=None)
         )
 
-        count = self.get_success(self.store.get_monthly_active_count_by_service())
-        self.assertEqual(count, {})
+        count1 = self.get_success(self.store.get_monthly_active_count_by_service())
+        self.assertEqual(count1, {})
 
         self.get_success(self.store.upsert_monthly_active_user(native_user1))
         self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
         self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
         self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
 
-        count = self.get_success(self.store.get_monthly_active_count())
-        self.assertEqual(count, 1)
+        count2 = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count2, 1)
 
         d = self.store.get_monthly_active_count_by_service()
         result = self.get_success(d)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 03e9cc7d4a..d8d17ef379 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         return event
 
     def test_redact(self):
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
         # Check event has not been redacted:
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
 
@@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_redact_join(self):
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(
-            self.inject_room_member(
-                self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
-            )
+        msg_event = self.inject_room_member(
+            self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
         )
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         # Check redaction
 
@@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_redact_censor(self):
         """Test that a redacted event gets censored in the DB after a month"""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
         # Check event has not been redacted:
         event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
         # Redact event
         reason = "Because I said so"
-        self.get_success(
-            self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
-        )
+        self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
 
         event = self.get_success(self.store.get_event(msg_event.event_id))
 
@@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_redact_redaction(self):
         """Tests that we can redact a redaction and can fetch it again."""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
-        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+        msg_event = self.inject_message(self.room1, self.u_alice, "t")
 
-        first_redact_event = self.get_success(
-            self.inject_redaction(
-                self.room1, msg_event.event_id, self.u_alice, "Redacting message"
-            )
+        first_redact_event = self.inject_redaction(
+            self.room1, msg_event.event_id, self.u_alice, "Redacting message"
         )
 
-        self.get_success(
-            self.inject_redaction(
-                self.room1,
-                first_redact_event.event_id,
-                self.u_alice,
-                "Redacting redaction",
-            )
+        self.inject_redaction(
+            self.room1,
+            first_redact_event.event_id,
+            self.u_alice,
+            "Redacting redaction",
         )
 
         # Now lets jump to the future where we have censored the redaction event
@@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
     def test_store_redacted_redaction(self):
         """Tests that we can store a redacted redaction."""
 
-        self.get_success(
-            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
-        )
+        self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index b8f09a8ee0..a2a9c05f24 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,11 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import TestHomeServer
@@ -31,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs: TestHomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:
 
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
@@ -44,7 +47,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # User elsewhere on another host
         self.u_charlie = UserID.from_string("@charlie:elsewhere")
 
-    def test_one_member(self):
+    def test_one_member(self) -> None:
 
         # Alice creates the room, and is automatically joined
         self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
@@ -57,7 +60,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual([self.room], [m.room_id for m in rooms_for_user])
 
-    def test_count_known_servers(self):
+    def test_count_known_servers(self) -> None:
         """
         _count_known_servers will calculate how many servers are in a room.
         """
@@ -68,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         servers = self.get_success(self.store._count_known_servers())
         self.assertEqual(servers, 2)
 
-    def test_count_known_servers_stat_counter_disabled(self):
+    def test_count_known_servers_stat_counter_disabled(self) -> None:
         """
         If enabled, the metrics for how many servers are known will be counted.
         """
@@ -85,7 +88,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"enable_metrics": True, "metrics_flags": {"known_servers": True}}
     )
-    def test_count_known_servers_stat_counter_enabled(self):
+    def test_count_known_servers_stat_counter_enabled(self) -> None:
         """
         If enabled, the metrics for how many servers are known will be counted.
         """
@@ -107,7 +110,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
-    def test_get_joined_users_from_context(self):
+    def test_get_joined_users_from_context(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
         bob_event = self.get_success(
             event_injection.inject_member_event(
@@ -161,7 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
 
-    def test__null_byte_in_display_name_properly_handled(self):
+    def test__null_byte_in_display_name_properly_handled(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
         res = self.get_success(
@@ -211,11 +214,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
-        self.room_creator = homeserver.get_room_creation_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.room_creator = hs.get_room_creation_handler()
 
-    def test_can_rerun_update(self):
+    def test_can_rerun_update(self) -> None:
         # First make sure we have completed all updates.
         self.wait_for_background_updates()
 
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index eaa0d7d749..52e41cdab4 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase):
     def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
         """Make a request to /messages with a filter, returns the chunk of events."""
 
-        from_token = self.get_success(
-            self.hs.get_event_sources().get_current_token_for_pagination()
-        )
+        from_token = self.hs.get_event_sources().get_current_token_for_pagination()
 
         events, next_key = self.get_success(
             self.hs.get_datastores().main.paginate_room_events(
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 09707a74d7..b01cae6e5d 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -16,23 +16,24 @@ import resource
 from unittest import mock
 
 from synapse.app.phone_stats_home import phone_stats_home
+from synapse.types import JsonDict
 
 from tests.unittest import HomeserverTestCase
 
 
 class PhoneHomeStatsTestCase(HomeserverTestCase):
-    def test_performance_frozen_clock(self):
+    def test_performance_frozen_clock(self) -> None:
         """
         If time doesn't move, don't error out.
         """
         past_stats = [
             (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
         ]
-        stats = {}
+        stats: JsonDict = {}
         self.get_success(phone_stats_home(self.hs, stats, past_stats))
         self.assertEqual(stats["cpu_average"], 0)
 
-    def test_performance_100(self):
+    def test_performance_100(self) -> None:
         """
         1 second of usage over 1 second is 100% CPU usage.
         """
@@ -43,7 +44,8 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
         old_resource.ru_maxrss = real_res.ru_maxrss
 
         past_stats = [(self.hs.get_clock().time(), old_resource)]
-        stats = {}
+        stats: JsonDict = {}
         self.reactor.advance(1)
-        self.get_success(phone_stats_home(self.hs, stats, past_stats))
+        # `old_resource` has type `Mock` instead of `struct_rusage`
+        self.get_success(phone_stats_home(self.hs, stats, past_stats))  # type: ignore[arg-type]
         self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 532e3fe9cd..d0230f9ebb 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.types import JsonDict, create_requester
 from synapse.visibility import filter_events_for_client, filter_events_for_server
 
@@ -47,17 +48,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self.get_success(self._inject_visibility("@admin:hs", "joined"))
+        self._inject_visibility("@admin:hs", "joined")
         for i in range(0, 10):
-            self.get_success(self._inject_room_member("@resident%i:hs" % i))
+            self._inject_room_member("@resident%i:hs" % i)
 
         events_to_filter = []
 
         for i in range(0, 10):
             user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
-            evt = self.get_success(
-                self._inject_room_member(user, extra_content={"a": "b"})
-            )
+            evt = self._inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
 
         filtered = self.get_success(
@@ -73,24 +72,57 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
             self.assertEqual(filtered[i].content["a"], "b")
 
+    def test_filter_outlier(self) -> None:
+        # outlier events must be returned, for the good of the collective federation
+        self._inject_room_member("@resident:remote_hs")
+        self._inject_visibility("@resident:remote_hs", "joined")
+
+        outlier = self._inject_outlier()
+        self.assertEqual(
+            self.get_success(
+                filter_events_for_server(self.storage, "remote_hs", [outlier])
+            ),
+            [outlier],
+        )
+
+        # it should also work when there are other events in the list
+        evt = self._inject_message("@unerased:local_hs")
+
+        filtered = self.get_success(
+            filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+        )
+        self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
+        self.assertEqual(filtered[0], outlier)
+        self.assertEqual(filtered[1].event_id, evt.event_id)
+        self.assertEqual(filtered[1].content, evt.content)
+
+        # ... but other servers should only be able to see the outlier (the other should
+        # be redacted)
+        filtered = self.get_success(
+            filter_events_for_server(self.storage, "other_server", [outlier, evt])
+        )
+        self.assertEqual(filtered[0], outlier)
+        self.assertEqual(filtered[1].event_id, evt.event_id)
+        self.assertNotIn("body", filtered[1].content)
+
     def test_erased_user(self) -> None:
         # 4 message events, from erased and unerased users, with a membership
         # change in the middle of them.
         events_to_filter = []
 
-        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+        evt = self._inject_message("@unerased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@erased:local_hs"))
+        evt = self._inject_message("@erased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
+        evt = self._inject_room_member("@joiner:remote_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+        evt = self._inject_message("@unerased:local_hs")
         events_to_filter.append(evt)
 
-        evt = self.get_success(self._inject_message("@erased:local_hs"))
+        evt = self._inject_message("@erased:local_hs")
         events_to_filter.append(evt)
 
         # the erasey user gets erased
@@ -187,6 +219,25 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.get_success(self.storage.persistence.persist_event(event, context))
         return event
 
+    def _inject_outlier(self) -> EventBase:
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@test:user",
+                "state_key": "@test:user",
+                "room_id": TEST_ROOM_ID,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
+        event.internal_metadata.outlier = True
+        self.get_success(
+            self.storage.persistence.persist_event(event, EventContext.for_outlier())
+        )
+        return event
+
 
 class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
     def test_out_of_band_invite_rejection(self):
diff --git a/tests/unittest.py b/tests/unittest.py
index 326895f4c9..9afa68c164 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,17 +16,17 @@
 import gc
 import hashlib
 import hmac
-import inspect
 import json
 import logging
 import secrets
 import time
 from typing import (
     Any,
-    AnyStr,
+    Awaitable,
     Callable,
     ClassVar,
     Dict,
+    Generic,
     Iterable,
     List,
     Optional,
@@ -40,6 +40,7 @@ from unittest.mock import Mock, patch
 import canonicaljson
 import signedjson.key
 import unpaddedbase64
+from typing_extensions import Protocol
 
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
@@ -50,7 +51,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Request
 
 from synapse import events
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -71,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 
-from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.server import (
+    CustomHeaderType,
+    FakeChannel,
+    get_clock,
+    make_request,
+    setup_test_homeserver,
+)
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
@@ -79,6 +86,17 @@ from tests.utils import default_config, setupdb
 setupdb()
 setup_logging()
 
+TV = TypeVar("TV")
+_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+
+
+class _TypedFailure(Generic[_ExcType], Protocol):
+    """Extension to twisted.Failure, where the 'value' has a certain type."""
+
+    @property
+    def value(self) -> _ExcType:
+        ...
+
 
 def around(target):
     """A CLOS-style 'around' modifier, which wraps the original method of the
@@ -277,6 +295,7 @@ class HomeserverTestCase(TestCase):
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
+                assert self.helper.auth_user_id is not None
 
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
@@ -289,6 +308,7 @@ class HomeserverTestCase(TestCase):
                 )
 
                 async def get_user_by_access_token(token=None, allow_guest=False):
+                    assert self.helper.auth_user_id is not None
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
                         "token_id": token_id,
@@ -296,6 +316,7 @@ class HomeserverTestCase(TestCase):
                     }
 
                 async def get_user_by_req(request, allow_guest=False, rights="access"):
+                    assert self.helper.auth_user_id is not None
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
                         token_id,
@@ -312,7 +333,7 @@ class HomeserverTestCase(TestCase):
                 )
 
         if self.needs_threadpool:
-            self.reactor.threadpool = ThreadPool()
+            self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
             self.addCleanup(self.reactor.threadpool.stop)
             self.reactor.threadpool.start()
 
@@ -427,7 +448,7 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -512,40 +533,36 @@ class HomeserverTestCase(TestCase):
 
         return hs
 
-    def pump(self, by=0.0):
+    def pump(self, by: float = 0.0) -> None:
         """
         Pump the reactor enough that Deferreds will fire.
         """
         self.reactor.pump([by] * 100)
 
-    def get_success(self, d, by=0.0):
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+    def get_success(
+        self,
+        d: Awaitable[TV],
+        by: float = 0.0,
+    ) -> TV:
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump(by=by)
-        return self.successResultOf(d)
+        return self.successResultOf(deferred)
 
-    def get_failure(self, d, exc):
+    def get_failure(
+        self, d: Awaitable[Any], exc: Type[_ExcType]
+    ) -> _TypedFailure[_ExcType]:
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+        deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump()
-        return self.failureResultOf(d, exc)
+        return self.failureResultOf(deferred, exc)
 
-    def get_success_or_raise(self, d, by=0.0):
+    def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
         """Drive deferred to completion and return result or raise exception
         on failure.
         """
-
-        if inspect.isawaitable(d):
-            deferred = ensureDeferred(d)
-        if not isinstance(deferred, Deferred):
-            return d
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
 
         results: list = []
         deferred.addBoth(results.append)
@@ -653,11 +670,11 @@ class HomeserverTestCase(TestCase):
 
     def login(
         self,
-        username,
-        password,
-        device_id=None,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ):
+        username: str,
+        password: str,
+        device_id: Optional[str] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
+    ) -> str:
         """
         Log in a user, and get an access token. Requires the Login API be
         registered.
@@ -679,18 +696,22 @@ class HomeserverTestCase(TestCase):
         return access_token
 
     def create_and_send_event(
-        self, room_id, user, soft_failed=False, prev_event_ids=None
-    ):
+        self,
+        room_id: str,
+        user: UserID,
+        soft_failed: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+    ) -> str:
         """
         Create and send an event.
 
         Args:
-            soft_failed (bool): Whether to create a soft failed event or not
-            prev_event_ids (list[str]|None): Explicitly set the prev events,
+            soft_failed: Whether to create a soft failed event or not
+            prev_event_ids: Explicitly set the prev events,
                 or if None just use the default
 
         Returns:
-            str: The new event's ID.
+            The new event's ID.
         """
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
@@ -717,34 +738,7 @@ class HomeserverTestCase(TestCase):
 
         return event.event_id
 
-    def add_extremity(self, room_id, event_id):
-        """
-        Add the given event as an extremity to the room.
-        """
-        self.get_success(
-            self.hs.get_datastores().main.db_pool.simple_insert(
-                table="event_forward_extremities",
-                values={"room_id": room_id, "event_id": event_id},
-                desc="test_add_extremity",
-            )
-        )
-
-        self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
-            (room_id,)
-        )
-
-    def attempt_wrong_password_login(self, username, password):
-        """Attempts to login as the user with the given password, asserting
-        that the attempt *fails*.
-        """
-        body = {"type": "m.login.password", "user": username, "password": password}
-
-        channel = self.make_request(
-            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
-        )
-        self.assertEqual(channel.code, 403, channel.result)
-
-    def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+    def inject_room_member(self, room: str, user: str, membership: str) -> None:
         """
         Inject a membership event into a room.
 
@@ -804,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         path: str,
         content: Optional[JsonDict] = None,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """Make an inbound signed federation request to this server
@@ -837,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             self.site,
             method=method,
             path=path,
-            content=content,
+            content=content or "",
             shorthand=False,
             await_result=await_result,
             custom_headers=custom_headers,
@@ -916,9 +910,6 @@ def override_config(extra_config):
     return decorator
 
 
-TV = TypeVar("TV")
-
-
 def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
     """A test decorator which will skip the decorated test unless a condition is set
 
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c4a3917b23..fa132391a1 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -13,160 +13,202 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable, Hashable, Tuple
+
 from twisted.internet import defer, reactor
-from twisted.internet.defer import CancelledError
+from twisted.internet.base import ReactorBase
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.logging.context import LoggingContext, current_context
-from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 
 from tests import unittest
 
 
 class LinearizerTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def test_linearizer(self):
+    def _start_task(
+        self, linearizer: Linearizer, key: Hashable
+    ) -> Tuple["Deferred[None]", "Deferred[None]", Callable[[], None]]:
+        """Starts a task which acquires the linearizer lock, blocks, then completes.
+
+        Args:
+            linearizer: The `Linearizer`.
+            key: The `Linearizer` key.
+
+        Returns:
+            A tuple containing:
+             * A cancellable `Deferred` for the entire task.
+             * A `Deferred` that resolves once the task acquires the lock.
+             * A function that unblocks the task. Must be called by the caller
+               to allow the task to release the lock and complete.
+        """
+        acquired_d: "Deferred[None]" = Deferred()
+        unblock_d: "Deferred[None]" = Deferred()
+
+        async def task() -> None:
+            with await linearizer.queue(key):
+                acquired_d.callback(None)
+                await unblock_d
+
+        d = defer.ensureDeferred(task())
+
+        def unblock() -> None:
+            unblock_d.callback(None)
+            # The next task, if it exists, will acquire the lock and require a kick of
+            # the reactor to advance.
+            self._pump()
+
+        return d, acquired_d, unblock
+
+    def _pump(self) -> None:
+        """Pump the reactor to advance `Linearizer`s."""
+        assert isinstance(reactor, ReactorBase)
+        while reactor.getDelayedCalls():
+            reactor.runUntilCurrent()
+
+    def test_linearizer(self) -> None:
+        """Tests that a task is queued up behind an earlier task."""
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        _, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
+
+        _, acquired_d2, unblock2 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        # Once the first task is done, the second task can continue.
+        unblock1()
+        self.assertTrue(acquired_d2.called)
 
-        with cm1:
-            self.assertFalse(d2.called)
+        unblock2()
 
-        with (yield d2):
-            pass
+    def test_linearizer_is_queued(self) -> None:
+        """Tests `Linearizer.is_queued`.
 
-    @defer.inlineCallbacks
-    def test_linearizer_is_queued(self):
+        Runs through the same scenario as `test_linearizer`.
+        """
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        _, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
 
-        # Since d1 gets called immediately, "is_queued" should return false.
+        # Since the first task acquires the lock immediately, "is_queued" should return
+        # false.
         self.assertFalse(linearizer.is_queued(key))
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        _, acquired_d2, unblock2 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        # Now d2 is queued up behind successful completion of cm1
+        # Now the second task is queued up behind the first.
         self.assertTrue(linearizer.is_queued(key))
 
-        with cm1:
-            self.assertFalse(d2.called)
-
-            # cm1 still not done, so d2 still queued.
-            self.assertTrue(linearizer.is_queued(key))
+        unblock1()
 
-        # And now d2 is called and nothing is in the queue again
+        # And now the second task acquires the lock and nothing is in the queue again.
+        self.assertTrue(acquired_d2.called)
         self.assertFalse(linearizer.is_queued(key))
 
-        with (yield d2):
-            self.assertFalse(linearizer.is_queued(key))
-
+        unblock2()
         self.assertFalse(linearizer.is_queued(key))
 
-    def test_lots_of_queued_things(self):
-        # we have one slow thing, and lots of fast things queued up behind it.
-        # it should *not* explode the stack.
+    def test_lots_of_queued_things(self) -> None:
+        """Tests lots of fast things queued up behind a slow thing.
+
+        The stack should *not* explode when the slow thing completes.
+        """
         linearizer = Linearizer()
+        key = ""
 
-        @defer.inlineCallbacks
-        def func(i, sleep=False):
+        async def func(i: int) -> None:
             with LoggingContext("func(%s)" % i) as lc:
-                with (yield linearizer.queue("")):
+                with (await linearizer.queue(key)):
                     self.assertEqual(current_context(), lc)
-                    if sleep:
-                        yield Clock(reactor).sleep(0)
 
                 self.assertEqual(current_context(), lc)
 
-        func(0, sleep=True)
+        _, _, unblock = self._start_task(linearizer, key)
         for i in range(1, 100):
-            func(i)
+            defer.ensureDeferred(func(i))
 
-        return func(1000)
+        d = defer.ensureDeferred(func(1000))
+        unblock()
+        self.successResultOf(d)
 
-    @defer.inlineCallbacks
-    def test_multiple_entries(self):
+    def test_multiple_entries(self) -> None:
+        """Tests a `Linearizer` with a concurrency above 1."""
         limiter = Linearizer(max_count=3)
 
         key = object()
 
-        d1 = limiter.queue(key)
-        cm1 = yield d1
-
-        d2 = limiter.queue(key)
-        cm2 = yield d2
-
-        d3 = limiter.queue(key)
-        cm3 = yield d3
-
-        d4 = limiter.queue(key)
-        self.assertFalse(d4.called)
-
-        d5 = limiter.queue(key)
-        self.assertFalse(d5.called)
+        _, acquired_d1, unblock1 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d1.called)
 
-        with cm1:
-            self.assertFalse(d4.called)
-            self.assertFalse(d5.called)
+        _, acquired_d2, unblock2 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d2.called)
 
-        cm4 = yield d4
-        self.assertFalse(d5.called)
+        _, acquired_d3, unblock3 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d3.called)
 
-        with cm3:
-            self.assertFalse(d5.called)
+        # These next two tasks have to wait.
+        _, acquired_d4, unblock4 = self._start_task(limiter, key)
+        self.assertFalse(acquired_d4.called)
 
-        cm5 = yield d5
+        _, acquired_d5, unblock5 = self._start_task(limiter, key)
+        self.assertFalse(acquired_d5.called)
 
-        with cm2:
-            pass
+        # Once the first task completes, the fourth task can continue.
+        unblock1()
+        self.assertTrue(acquired_d4.called)
+        self.assertFalse(acquired_d5.called)
 
-        with cm4:
-            pass
+        # Once the third task completes, the fifth task can continue.
+        unblock3()
+        self.assertTrue(acquired_d5.called)
 
-        with cm5:
-            pass
+        # Make all tasks finish.
+        unblock2()
+        unblock4()
+        unblock5()
 
-        d6 = limiter.queue(key)
-        with (yield d6):
-            pass
+        # The next task shouldn't have to wait.
+        _, acquired_d6, unblock6 = self._start_task(limiter, key)
+        self.assertTrue(acquired_d6)
+        unblock6()
 
-    @defer.inlineCallbacks
-    def test_cancellation(self):
+    def test_cancellation(self) -> None:
+        """Tests cancellation while waiting for a `Linearizer`."""
         linearizer = Linearizer()
 
         key = object()
 
-        d1 = linearizer.queue(key)
-        cm1 = yield d1
+        d1, acquired_d1, unblock1 = self._start_task(linearizer, key)
+        self.assertTrue(acquired_d1.called)
 
-        d2 = linearizer.queue(key)
-        self.assertFalse(d2.called)
+        # Create a second task, waiting for the first task.
+        d2, acquired_d2, _ = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d2.called)
 
-        d3 = linearizer.queue(key)
-        self.assertFalse(d3.called)
+        # Create a third task, waiting for the second task.
+        d3, acquired_d3, unblock3 = self._start_task(linearizer, key)
+        self.assertFalse(acquired_d3.called)
 
+        # Cancel the waiting second task.
         d2.cancel()
 
-        with cm1:
-            pass
+        unblock1()
+        self.successResultOf(d1)
 
         self.assertTrue(d2.called)
-        try:
-            yield d2
-            self.fail("Expected d2 to raise CancelledError")
-        except CancelledError:
-            pass
-
-        with (yield d3):
-            pass
+        self.failureResultOf(d2, CancelledError)
+
+        # The third task should continue running.
+        self.assertTrue(
+            acquired_d3.called,
+            "Third task did not get the lock after the second task was cancelled",
+        )
+        unblock3()
+        self.successResultOf(d3)
diff --git a/tests/utils.py b/tests/utils.py
index f6b1d60371..d4ba3a9b99 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -35,6 +35,11 @@ LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
 POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
 POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
 POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
+POSTGRES_PORT = (
+    int(os.environ["SYNAPSE_POSTGRES_PORT"])
+    if "SYNAPSE_POSTGRES_PORT" in os.environ
+    else None
+)
 POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
 
 # When debugging a specific test, it's occasionally useful to write the
@@ -55,6 +60,7 @@ def setupdb():
         db_conn = db_engine.module.connect(
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
+            port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
             dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
         )
@@ -73,6 +79,7 @@ def setupdb():
             database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
+            port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
         db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
@@ -83,6 +90,7 @@ def setupdb():
             db_conn = db_engine.module.connect(
                 user=POSTGRES_USER,
                 host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
                 dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
             )