Diffstat (limited to 'tests')
-rw-r--r--  tests/api/test_auth.py                                       |  57
-rw-r--r--  tests/federation/transport/server/test__base.py             |  10
-rw-r--r--  tests/handlers/test_auth.py                                  |   2
-rw-r--r--  tests/handlers/test_device.py                                |   4
-rw-r--r--  tests/handlers/test_oidc.py                                  |   7
-rw-r--r--  tests/handlers/test_register.py                              |   2
-rw-r--r--  tests/handlers/test_room_summary.py                          |   2
-rw-r--r--  tests/handlers/test_stats.py                                 |  28
-rw-r--r--  tests/handlers/test_sync.py                                  |   2
-rw-r--r--  tests/http/server/_base.py                                   | 580
-rw-r--r--  tests/http/test_matrixfederationclient.py (renamed from tests/http/test_fedclient.py) |  14
-rw-r--r--  tests/http/test_servlet.py                                   |  10
-rw-r--r--  tests/replication/http/test__base.py                         |  10
-rw-r--r--  tests/rest/client/test_profile.py                            |   8
-rw-r--r--  tests/rest/client/test_relations.py                          |   1
-rw-r--r--  tests/rest/client/test_rooms.py                              | 275
-rw-r--r--  tests/server_notices/test_resource_limits_server_notices.py |  22
-rw-r--r--  tests/state/test_v2.py                                       | 125
-rw-r--r--  tests/test_event_auth.py                                     | 357
-rw-r--r--  tests/test_server.py                                         |  14
-rw-r--r--  tests/test_state.py                                          |  11
-rw-r--r--  tests/test_types.py                                          |  13
-rw-r--r--  tests/unittest.py                                            |   2
-rw-r--r--  tests/util/test_macaroons.py                                 | 146
24 files changed, 1337 insertions(+), 365 deletions(-)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bc75ddd3e9..dfcfaf79b6 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,6 +19,7 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
     AuthError,
@@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.auth._auth_blocking
+        self.auth_blocking = AuthBlocking(hs)
 
         self.test_user = "@foo:bar"
         self.test_token = b"_test_token_"
@@ -312,9 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.store.insert_client_ip.call_count, 2)
 
     def test_get_user_from_macaroon(self):
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
-        )
+        self.store.get_user_by_access_token = simple_async_mock(None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -322,17 +321,14 @@ class AuthTestCase(unittest.HomeserverTestCase):
             identifier="key",
             key=self.hs.config.key.macaroon_secret_key,
         )
+        # "Legacy" macaroons should not work for regular users not in the database
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        user_info = self.get_success(
-            self.auth.get_user_by_access_token(macaroon.serialize())
+        serialized = macaroon.serialize()
+        self.get_failure(
+            self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
         )
-        self.assertEqual(user_id, user_info.user_id)
-
-        # TODO: device_id should come from the macaroon, but currently comes
-        # from the db.
-        self.assertEqual(user_info.device_id, "device")
 
     def test_get_guest_user_from_macaroon(self):
         self.store.get_user_by_id = simple_async_mock({"is_guest": True})
@@ -362,20 +358,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
         small_number_of_users = 1
 
         # Ensure no error thrown
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
         self.auth_blocking._limit_usage_by_mau = True
 
         self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
 
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
         self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self):
         self.auth_blocking._max_mau_value = 50
@@ -383,15 +381,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Support users allowed
-        self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+        self.get_success(
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
+        )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Bots not allowed
         self.get_failure(
-            self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
+            ResourceLimitError,
         )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Real users not allowed
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
     def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -419,7 +420,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             app_service=appservice,
             authenticated_entity="@appservice:server",
         )
-        self.get_success(self.auth.check_auth_blocking(requester=requester))
+        self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
 
     def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -448,7 +449,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             authenticated_entity="@appservice:server",
         )
         self.get_failure(
-            self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(requester=requester),
+            ResourceLimitError,
         )
 
     def test_reserved_threepid(self):
@@ -459,18 +461,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
 
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
         self.get_failure(
-            self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
+            ResourceLimitError,
         )
 
-        self.get_success(self.auth.check_auth_blocking(threepid=threepid))
+        self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
 
     def test_hs_disabled(self):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -485,7 +490,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -495,4 +502,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
         user = "@user:server"
         self.auth_blocking._server_notices_mxid = user
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        self.get_success(self.auth.check_auth_blocking(user))
+        self.get_success(self.auth_blocking.check_auth_blocking(user))
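The tests above now exercise the monthly-active-user and hs-disabled checks through a standalone `AuthBlocking` instance (or `hs.get_auth_blocking()`) rather than reaching into `Auth._auth_blocking`. The following is a minimal sketch of that pattern, not part of this patch: the class name, the test name and the `hs.get_datastores().main` lookup are illustrative assumptions, while `AuthBlocking`, `simple_async_mock`, `ResourceLimitError` and the `HomeserverTestCase` helpers are the ones used above.

from synapse.api.auth_blocking import AuthBlocking
from synapse.api.errors import ResourceLimitError

from tests import unittest
from tests.test_utils import simple_async_mock


class AuthBlockingSketchTestCase(unittest.HomeserverTestCase):
    """Sketch only: mirrors the MAU-blocking pattern used in the tests above."""

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastores().main
        # Build an AuthBlocking directly so its config can be tweaked per test.
        self.auth_blocking = AuthBlocking(hs)
        self.auth_blocking._limit_usage_by_mau = True
        self.auth_blocking._max_mau_value = 50

    def test_mau_blocking(self):
        # Over the MAU cap: check_auth_blocking raises a 403 ResourceLimitError.
        self.store.get_monthly_active_count = simple_async_mock(100)
        self.get_failure(
            self.auth_blocking.check_auth_blocking(), ResourceLimitError
        )

        # Back under the cap: the check passes without raising.
        self.store.get_monthly_active_count = simple_async_mock(1)
        self.get_success(self.auth_blocking.check_auth_blocking())
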
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e63885c1c9..d33e86db4c 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 class CancellableFederationServlet(BaseFederationServlet):
@@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
         return HTTPStatus.OK, {"result": True}
 
 
-class BaseFederationServletCancellationTests(
-    unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
     """Tests for `BaseFederationServlet` cancellation."""
 
     skip = "`BaseFederationServlet` does not support cancellation yet."
@@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
         # request won't be processed.
         self.pump()
 
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
         # request won't be processed.
         self.pump()
 
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 67a7829769..7106799d44 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -38,7 +38,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth_blocking()
         self.auth_blocking._max_mau_value = 50
 
         self.small_number_of_users = 1
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 01ea7d2a42..b8b465d35b 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -154,7 +154,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self._record_users()
 
         # delete the device
-        self.get_success(self.handler.delete_device(user1, "abc"))
+        self.get_success(self.handler.delete_devices(user1, ["abc"]))
 
         # check the device was deleted
         self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
@@ -179,7 +179,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
 
         # delete the device
-        self.get_success(self.handler.delete_device(user1, "abc"))
+        self.get_success(self.handler.delete_devices(user1, ["abc"]))
 
         # check that the device_inbox was deleted
         res = self.get_success(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1231aed944..e6cd3af7b7 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -25,7 +25,7 @@ from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID
 from synapse.util import Clock
-from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -1227,7 +1227,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     ) -> str:
         from synapse.handlers.oidc import OidcSessionData
 
-        return self.handler._token_generator.generate_oidc_session_token(
+        return self.handler._macaroon_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
                 idp_id="oidc",
@@ -1251,7 +1251,6 @@ async def _make_callback_with_userinfo(
         userinfo: the OIDC userinfo dict
         client_redirect_url: the URL to redirect to on success.
     """
-    from synapse.handlers.oidc import OidcSessionData
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
@@ -1260,7 +1259,7 @@ async def _make_callback_with_userinfo(
     provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
 
     state = "state"
-    session = handler._token_generator.generate_oidc_session_token(
+    session = handler._macaroon_generator.generate_oidc_session_token(
         state=state,
         session_data=OidcSessionData(
             idp_id="oidc",
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b6ba19c739..23f35d5bf5 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -699,7 +699,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-        await self.hs.get_auth().check_auth_blocking()
+        await self.hs.get_auth_blocking().check_auth_blocking()
         need_register = True
 
         try:
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 0546655690..aa650756e4 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -178,7 +178,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             result_room_ids.append(result_room["room_id"])
             result_children_ids.append(
                 [
-                    (cs["room_id"], cs["state_key"])
+                    (result_room["room_id"], cs["state_key"])
                     for cs in result_room["children_state"]
                 ]
             )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index ecd78fa369..05f9ec3c51 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -46,16 +46,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db_pool.simple_insert(
                 "background_updates",
-                {"update_name": "populate_stats_prepare", "progress_json": "{}"},
-            )
-        )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
                 {
                     "update_name": "populate_stats_process_rooms",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_prepare",
                 },
             )
         )
@@ -69,16 +62,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 },
             )
         )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_stats_cleanup",
-                    "progress_json": "{}",
-                    "depends_on": "populate_stats_process_users",
-                },
-            )
-        )
 
     async def get_all_room_state(self):
         return await self.store.db_pool.simple_select_list(
@@ -533,7 +516,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 {
                     "update_name": "populate_stats_process_rooms",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_prepare",
                 },
             )
         )
@@ -547,16 +529,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 },
             )
         )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_stats_cleanup",
-                    "progress_json": "{}",
-                    "depends_on": "populate_stats_process_users",
-                },
-            )
-        )
 
         self.wait_for_background_updates()
 
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index db3302a4c7..ecc7cc6461 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -45,7 +45,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = self.hs.get_auth_blocking()
 
     def test_wait_for_sync_for_user_auth_blocking(self):
         user_id1 = "@user1:test"
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index b9f1a381aa..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -12,89 +12,543 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+import itertools
+import logging
 from http import HTTPStatus
-from typing import Any, Callable, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 from unittest import mock
+from unittest.mock import Mock
 
+from twisted.internet.defer import Deferred
 from twisted.internet.error import ConnectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.server import Site
 
 from synapse.http.server import (
     HTTP_STATUS_REQUEST_CANCELLED,
     respond_with_html_bytes,
     respond_with_json,
 )
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.types import JsonDict
 
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock
+from tests.server import FakeChannel, make_request
+from tests.unittest import logcontext_clean
 
+logger = logging.getLogger(__name__)
 
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
-    """Provides helper methods for testing cancellation of endpoints."""
 
-    def _test_disconnect(
-        self,
-        reactor: ThreadedMemoryReactorClock,
-        channel: FakeChannel,
-        expect_cancellation: bool,
-        expected_body: Union[bytes, JsonDict],
-        expected_code: Optional[int] = None,
-    ) -> None:
-        """Disconnects an in-flight request and checks the response.
+T = TypeVar("T")
 
-        Args:
-            reactor: The twisted reactor running the request handler.
-            channel: The `FakeChannel` for the request.
-            expect_cancellation: `True` if request processing is expected to be
-                cancelled, `False` if the request should run to completion.
-            expected_body: The expected response for the request.
-            expected_code: The expected status code for the request. Defaults to `200`
-                or `499` depending on `expect_cancellation`.
-        """
-        # Determine the expected status code.
-        if expected_code is None:
-            if expect_cancellation:
-                expected_code = HTTP_STATUS_REQUEST_CANCELLED
-            else:
-                expected_code = HTTPStatus.OK
-
-        request = channel.request
-        self.assertFalse(
-            channel.is_finished(),
+
+def test_disconnect(
+    reactor: MemoryReactorClock,
+    channel: FakeChannel,
+    expect_cancellation: bool,
+    expected_body: Union[bytes, JsonDict],
+    expected_code: Optional[int] = None,
+) -> None:
+    """Disconnects an in-flight request and checks the response.
+
+    Args:
+        reactor: The twisted reactor running the request handler.
+        channel: The `FakeChannel` for the request.
+        expect_cancellation: `True` if request processing is expected to be cancelled,
+            `False` if the request should run to completion.
+        expected_body: The expected response for the request.
+        expected_code: The expected status code for the request. Defaults to `200` or
+            `499` depending on `expect_cancellation`.
+    """
+    # Determine the expected status code.
+    if expected_code is None:
+        if expect_cancellation:
+            expected_code = HTTP_STATUS_REQUEST_CANCELLED
+        else:
+            expected_code = HTTPStatus.OK
+
+    request = channel.request
+    if channel.is_finished():
+        raise AssertionError(
             "Request finished before we could disconnect - "
-            "was `await_result=False` passed to `make_request`?",
+            "ensure `await_result=False` is passed to `make_request`.",
         )
 
-        # We're about to disconnect the request. This also disconnects the channel, so
-        # we have to rely on mocks to extract the response.
-        respond_method: Callable[..., Any]
-        if isinstance(expected_body, bytes):
-            respond_method = respond_with_html_bytes
+    # We're about to disconnect the request. This also disconnects the channel, so we
+    # have to rely on mocks to extract the response.
+    respond_method: Callable[..., Any]
+    if isinstance(expected_body, bytes):
+        respond_method = respond_with_html_bytes
+    else:
+        respond_method = respond_with_json
+
+    with mock.patch(
+        f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+    ) as respond_mock:
+        # Disconnect the request.
+        request.connectionLost(reason=ConnectionDone())
+
+        if expect_cancellation:
+            # An immediate cancellation is expected.
+            respond_mock.assert_called_once()
         else:
-            respond_method = respond_with_json
+            respond_mock.assert_not_called()
 
-        with mock.patch(
-            f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
-        ) as respond_mock:
-            # Disconnect the request.
-            request.connectionLost(reason=ConnectionDone())
+            # The handler is expected to run to completion.
+            reactor.advance(1.0)
+            respond_mock.assert_called_once()
 
-            if expect_cancellation:
-                # An immediate cancellation is expected.
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
-            else:
-                respond_mock.assert_not_called()
-
-                # The handler is expected to run to completion.
-                reactor.pump([1.0])
+        args, _kwargs = respond_mock.call_args
+        code, body = args[1], args[2]
+
+        if code != expected_code:
+            raise AssertionError(
+                f"{code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if request.code != expected_code:
+            raise AssertionError(
+                f"{request.code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if body != expected_body:
+            raise AssertionError(
+                f"{body!r} != {expected_body!r} : "
+                "Request did not finish with the expected status code."
+            )
+
+
+@logcontext_clean
+def make_request_with_cancellation_test(
+    test_name: str,
+    reactor: MemoryReactorClock,
+    site: Site,
+    method: str,
+    path: str,
+    content: Union[bytes, str, JsonDict] = b"",
+) -> FakeChannel:
+    """Performs a request repeatedly, disconnecting at successive `await`s, until
+    one completes.
+
+    Fails if:
+        * A logging context is lost during cancellation.
+        * A logging context gets restarted after it is marked as finished, eg. if
+            a request's logging context is used by some processing started by the
+            request, but the request neglects to cancel that processing or wait for it
+            to complete.
+
+            Note that "Re-starting finished log context" errors get raised within the
+            request handling code and may or may not get caught. These errors will
+            likely manifest as a different logging context error at a later point. When
+            debugging logging context failures, setting a breakpoint in
+            `logcontext_error` can prove useful.
+        * A request gets stuck, possibly due to a previous cancellation.
+        * The request does not return a 499 when the client disconnects.
+            This implies that a `CancelledError` was swallowed somewhere.
+
+    It is up to the caller to verify that the request returns the correct data when
+    it finally runs to completion.
+
+    Note that this function can only cover a single code path and does not guarantee
+    that an endpoint is compatible with cancellation on every code path.
+    To allow inspection of the code path that is being tested, this function will
+    log the stack trace at every `await` that gets cancelled. To view these log
+    lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment
+    variable, which will include the log lines in `_trial_temp/test.log`.
+    Alternatively, `_log_for_request` can be modified to write to `sys.stdout`.
+
+    Args:
+        test_name: The name of the test, which will be logged.
+        reactor: The twisted reactor running the request handler.
+        site: The twisted `Site` to use to render the request.
+        method: The HTTP request method ("verb").
+        path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and
+            such).
+        content: The body of the request.
+
+    Returns:
+        The `FakeChannel` object which stores the result of the final request that
+        runs to completion.
+    """
+    # To process a request, a coroutine run is created for the async method handling
+    # the request. That method may then start other coroutine runs, wrapped in
+    # `Deferred`s.
+    #
+    # We would like to trigger a cancellation at the first `await`, re-run the
+    # request and cancel at the second `await`, and so on. By patching
+    # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+    # have not seen, and force them to block when they wouldn't have.
+
+    # The set of previously seen `await`s.
+    # Each element is a stringified stack trace.
+    seen_awaits: Set[Tuple[str, ...]] = set()
+
+    _log_for_request(
+        0, f"Running make_request_with_cancellation_test for {test_name}..."
+    )
+
+    for request_number in itertools.count(1):
+        deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+
+        try:
+            with mock.patch(
+                "synapse.http.server.respond_with_json", wraps=respond_with_json
+            ) as respond_mock:
+                with deferred_patch.patch():
+                    # Start the request.
+                    channel = make_request(
+                        reactor, site, method, path, content, await_result=False
+                    )
+                    request = channel.request
+
+                    # Run the request until we see a new `await` which we have not
+                    # yet cancelled at, or it completes.
+                    while not respond_mock.called and not deferred_patch.new_await_seen:
+                        previous_awaits_seen = deferred_patch.awaits_seen
+
+                        reactor.advance(0.0)
+
+                        if deferred_patch.awaits_seen == previous_awaits_seen:
+                            # We didn't see any progress. Try advancing the clock.
+                            reactor.advance(1.0)
+
+                        if deferred_patch.awaits_seen == previous_awaits_seen:
+                            # We still didn't see any progress. The request might be
+                            # stuck.
+                            raise AssertionError(
+                                "Request appears to be stuck, possibly due to a "
+                                "previous cancelled request"
+                            )
+
+                if respond_mock.called:
+                    # The request ran to completion and we are done with testing it.
+
+                    # `respond_with_json` writes the response asynchronously, so we
+                    # might have to give the reactor a kick before the channel gets
+                    # the response.
+                    deferred_patch.unblock_awaits()
+                    channel.await_result()
+
+                    return channel
+
+                # Disconnect the client and wait for the response.
+                request.connectionLost(reason=ConnectionDone())
+
+                _log_for_request(request_number, "--- disconnected ---")
+
+                # Advance the reactor just enough to get a response.
+                # We don't want to advance the reactor too far, because we can only
+                # detect re-starts of finished logging contexts after we set the
+                # finished flag below.
+                for _ in range(2):
+                    # We may need to pump the reactor to allow `delay_cancellation`s to
+                    # finish.
+                    if not respond_mock.called:
+                        reactor.advance(0.0)
+
+                    # Try advancing the clock if that didn't work.
+                    if not respond_mock.called:
+                        reactor.advance(1.0)
+
+                    # `delay_cancellation`s may be waiting for processing that we've
+                    # forced to block. Try unblocking them, followed by another round of
+                    # pumping the reactor.
+                    if not respond_mock.called:
+                        deferred_patch.unblock_awaits()
+
+                # Mark the request's logging context as finished. If it gets
+                # activated again, an `AssertionError` will be raised and bubble up
+                # through request handling code. This `AssertionError` may or may not be
+                # caught. Eventually some other code will deactivate the logging
+                # context which will raise a different `AssertionError` because
+                # resource usage won't have been correctly tracked.
+                if isinstance(request, SynapseRequest) and request.logcontext:
+                    request.logcontext.finished = True
+
+                # Check that the request finished with a 499,
+                # ie. the `CancelledError` wasn't swallowed.
                 respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
+
+                if request.code != HTTP_STATUS_REQUEST_CANCELLED:
+                    raise AssertionError(
+                        f"{request.code} != {HTTP_STATUS_REQUEST_CANCELLED} : "
+                        "Cancelled request did not finish with the correct status code."
+                    )
+        finally:
+            # Unblock any processing that might be shared between requests, if we
+            # haven't already done so.
+            deferred_patch.unblock_awaits()
+
+    assert False, "unreachable"  # noqa: B011
+
+
+class Deferred__next__Patch:
+    """A `Deferred.__next__` patch that will intercept `await`s and force them
+    to block once it sees a new `await`.
+
+    When done with the patch, `unblock_awaits()` must be called to clean up after any
+    `await`s that were forced to block, otherwise processing shared between multiple
+    requests, such as database queries started by `@cached`, will become permanently
+    stuck.
+
+    Usage:
+        seen_awaits = set()
+        deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+        try:
+            with deferred_patch.patch():
+                # do things
+                ...
+        finally:
+            deferred_patch.unblock_awaits()
+    """
+
+    def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int):
+        """
+        Args:
+            seen_awaits: The set of stack traces of `await`s that have been previously
+                seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+                it to the set.
+            request_number: The request number to log against.
+        """
+        self._request_number = request_number
+        self._seen_awaits = seen_awaits
+
+        self._original_Deferred___next__ = Deferred.__next__
+
+        # The number of `await`s on `Deferred`s we have seen so far.
+        self.awaits_seen = 0
+
+        # Whether we have seen a new `await` not in `seen_awaits`.
+        self.new_await_seen = False
+
+        # To force `await`s on resolved `Deferred`s to block, we make up a new
+        # unresolved `Deferred` and return it out of `Deferred.__next__` /
+        # `coroutine.send()`. We have to resolve it later, in case the `await`ing
+        # coroutine is part of some shared processing, such as `@cached`.
+        self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
+
+        # The last stack we logged.
+        self._previous_stack: List[inspect.FrameInfo] = []
+
+    def patch(self) -> ContextManager[Mock]:
+        """Returns a context manager which patches `Deferred.__next__`."""
+
+        def Deferred___next__(
+            deferred: "Deferred[T]", value: object = None
+        ) -> "Deferred[T]":
+            """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+            seen enough of them.
+
+            `Deferred.__next__` will normally:
+                * return `self` if the `Deferred` is unresolved, in which case
+                   `coroutine.send()` will return the `Deferred`, and
+                   `_defer.inlineCallbacks` will stop running the coroutine until the
+                   `Deferred` is resolved.
+                * raise a `StopIteration(result)`, containing the result of the `await`.
+                * raise another exception, which will come out of the `await`.
+            """
+            self.awaits_seen += 1
+
+            stack = _get_stack(skip_frames=1)
+            stack_hash = _hash_stack(stack)
+
+            if stack_hash not in self._seen_awaits:
+                # Block at the current `await` onwards.
+                self._seen_awaits.add(stack_hash)
+                self.new_await_seen = True
+
+            if not self.new_await_seen:
+                # This `await` isn't interesting. Let it proceed normally.
+
+                # Don't log the stack. It's been seen before in a previous run.
+                self._previous_stack = stack
+
+                return self._original_Deferred___next__(deferred, value)
+
+            # We want to block at the current `await`.
+            if deferred.called and not deferred.paused:
+                # This `Deferred` already has a result.
+                # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
+                # on. This blocks the coroutine that did this `await`.
+                # We queue it up for unblocking later.
+                new_deferred: "Deferred[T]" = Deferred()
+                self._to_unblock[new_deferred] = deferred.result
+
+                _log_await_stack(
+                    stack,
+                    self._previous_stack,
+                    self._request_number,
+                    "force-blocked await",
+                )
+                self._previous_stack = stack
+
+                return make_deferred_yieldable(new_deferred)
+
+            # This `Deferred` does not have a result yet.
+            # The `await` will block normally, so we don't have to do anything.
+            _log_await_stack(
+                stack,
+                self._previous_stack,
+                self._request_number,
+                "blocking await",
+            )
+            self._previous_stack = stack
+
+            return self._original_Deferred___next__(deferred, value)
+
+        return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+
+    def unblock_awaits(self) -> None:
+        """Unblocks any shared processing that we forced to block.
+
+        Must be called when done, otherwise processing shared between multiple requests,
+        such as database queries started by `@cached`, will become permanently stuck.
+        """
+        to_unblock = self._to_unblock
+        self._to_unblock = {}
+        for deferred, result in to_unblock.items():
+            deferred.callback(result)
+
+
+def _log_for_request(request_number: int, message: str) -> None:
+    """Logs a message for an iteration of `make_request_with_cancellation_test`."""
+    # We want consistent alignment when logging stack traces, so ensure the logging
+    # context has a fixed width name.
+    with LoggingContext(name=f"request-{request_number:<2}"):
+        logger.info(message)
+
+
+def _log_await_stack(
+    stack: List[inspect.FrameInfo],
+    previous_stack: List[inspect.FrameInfo],
+    request_number: int,
+    note: str,
+) -> None:
+    """Logs the stack for an `await` in `make_request_with_cancellation_test`.
+
+    Only logs the part of the stack that has changed since the previous call.
+
+    Example output looks like:
+    ```
+    delay_cancellation:750 (synapse/util/async_helpers.py:750)
+        DatabasePool._runInteraction:768 (synapse/storage/database.py:768)
+            > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891)
+    ```
+
+    Args:
+        stack: The stack to log, as returned by `_get_stack()`.
+        previous_stack: The previous stack logged, with callers appearing before
+            callees.
+        request_number: The request number to log against.
+        note: A note to attach to the last stack frame, eg. "blocked on await".
+    """
+    for i, frame_info in enumerate(stack[:-1]):
+        # Skip any frames in common with the previous logging.
+        if i < len(previous_stack) and frame_info == previous_stack[i]:
+            continue
+
+        frame = _format_stack_frame(frame_info)
+        message = f"{'  ' * i}{frame}"
+        _log_for_request(request_number, message)
+
+    # Always print the final frame with the `await`.
+    # If the frame with the `await` started another coroutine run, we may have already
+    # printed a deeper stack which includes our final frame. We want to log where all
+    # `await`s happen, so we reprint the frame in this case.
+    i = len(stack) - 1
+    frame_info = stack[i]
+    frame = _format_stack_frame(frame_info)
+    message = f"{'  ' * i}> *{note}* at {frame}"
+    _log_for_request(request_number, message)
+
+
+def _format_stack_frame(frame_info: inspect.FrameInfo) -> str:
+    """Returns a string representation of a stack frame.
+
+    Used for debug logging.
+
+    Returns:
+        A string, formatted like
+        "JsonResource._async_render:559 (synapse/http/server.py:559)".
+    """
+    method_name = _get_stack_frame_method_name(frame_info)
+
+    return (
+        f"{method_name}:{frame_info.lineno} ({frame_info.filename}:{frame_info.lineno})"
+    )
+
+
+def _get_stack(skip_frames: int) -> List[inspect.FrameInfo]:
+    """Captures the stack for a request.
+
+    Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`.
+
+    Used for debug logging.
+
+    Returns:
+        A list of `inspect.FrameInfo`s, with callers appearing before callees.
+    """
+    stack = []
+
+    skip_frames += 1  # Also skip `get_stack` itself.
+
+    for frame_info in inspect.stack()[skip_frames:]:
+        # Skip any twisted `inlineCallbacks` gunk.
+        if "/twisted/" in frame_info.filename:
+            continue
+
+        # Exclude the reactor frame, upwards.
+        method_name = _get_stack_frame_method_name(frame_info)
+        if method_name == "ThreadedMemoryReactorClock.advance":
+            break
+
+        stack.append(frame_info)
+
+        # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the entry
+        # point for request handling.
+        if frame_info.function == "wrapped_async_request_handler":
+            break
+
+    return stack[::-1]
+
+
+def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
+    """Returns the name of a stack frame's method.
+
+    eg. "JsonResource._async_render".
+    """
+    method_name = frame_info.function
+
+    # Prefix the class name for instance methods.
+    frame_self = frame_info.frame.f_locals.get("self")
+    if frame_self:
+        method = getattr(frame_self, method_name, None)
+        if method:
+            method_name = method.__qualname__
+        else:
+            # We couldn't find the method on `self`.
+            # Make something up. It's useful to know which class "contains" a
+            # function anyway.
+            method_name = f"{type(frame_self).__name__} {method_name}"
+
+    return method_name
+
+
+def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]:
+    """Turns a stack into a hashable value that can be put into a set."""
+    return tuple(_format_stack_frame(frame) for frame in stack)
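The comments above describe the trick behind `Deferred__next__Patch`: every `await` on a `Deferred` inside a coroutine driven by Twisted goes through `Deferred.__next__`, so patching that method on the class lets a test observe, and if desired block, each such `await`. Below is a stripped-down sketch of the observation half of that idea, separate from the helper above; all names (`count_deferred_awaits`, `counting_next`, `_two_awaits`) are made up, and only Twisted and the standard library are assumed.

from unittest import mock

from twisted.internet.defer import Deferred, ensureDeferred, succeed


def count_deferred_awaits(coro_factory) -> int:
    """Runs a coroutine to completion and counts its `await`s on `Deferred`s."""
    original_next = Deferred.__next__
    awaits_seen = 0

    def counting_next(deferred, value=None):
        nonlocal awaits_seen
        awaits_seen += 1
        # Delegate to the real implementation so the coroutine keeps making progress.
        return original_next(deferred, value)

    with mock.patch.object(Deferred, "__next__", new=counting_next):
        # `ensureDeferred` starts driving the coroutine immediately; since every
        # awaited `Deferred` below is already resolved, it runs to completion here.
        ensureDeferred(coro_factory())

    return awaits_seen


async def _two_awaits() -> None:
    await succeed(1)
    await succeed(2)


assert count_deferred_awaits(_two_awaits) == 2

The real helper goes further: instead of merely counting, it hands back fresh unresolved `Deferred`s to force chosen `await`s to block, which is what lets `make_request_with_cancellation_test` disconnect the request at each successive `await` and later unblock any shared processing via `unblock_awaits()`.
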
diff --git a/tests/http/test_fedclient.py b/tests/http/test_matrixfederationclient.py
index 006dbab093..be9eaf34e8 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -617,3 +617,17 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
 
         self.assertTrue(transport.disconnecting)
+
+    def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
+        with self.assertRaises(ValueError):
+            self.cl.build_auth_headers(None, b"GET", b"https://example.com")
+        with self.assertRaises(ValueError):
+            self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
+        with self.assertRaises(ValueError):
+            self.cl.build_auth_headers(
+                None, b"GET", b"https://example.com", destination_is=b""
+            )
+        with self.assertRaises(ValueError):
+            self.cl.build_auth_headers(
+                b"", b"GET", b"https://example.com", destination_is=b""
+            )
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
         return HTTPStatus.OK, {"result": True}
 
 
-class TestRestServletCancellation(
-    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
     """Tests for `RestServlet` cancellation."""
 
     servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
     def test_cancellable_disconnect(self) -> None:
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         channel = self.make_request("GET", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
     def test_uncancellable_disconnect(self) -> None:
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         channel = self.make_request("POST", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..822a957c3a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -25,7 +25,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
         return HTTPStatus.OK, {"result": True}
 
 
-class ReplicationEndpointCancellationTestCase(
-    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     """Tests for `ReplicationEndpoint` cancellation."""
 
     def create_test_resource(self):
@@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
         channel = self.make_request("POST", path, await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
         channel = self.make_request("POST", path, await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 77c3ced42e..29bed0e872 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
+import urllib.parse
+from http import HTTPStatus
 from typing import Any, Dict, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_displayname()
         self.assertEqual(res, "owner")
 
+    def test_get_displayname_rejects_bad_username(self) -> None:
+        channel = self.make_request(
+            "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+
     def test_set_displayname(self) -> None:
         channel = self.make_request(
             "PUT",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..aa84906548 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase):
 
 
 class RelationPaginationTestCase(BaseRelationsTestCase):
+    @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
     def test_basic_paginate_relations(self) -> None:
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..35c59ee9e0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,13 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
+# `Literal` was added to `typing` in Python 3.8.
+from typing_extensions import Literal
+
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -42,6 +45,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
 from tests.test_utils import make_awaitable
 
 PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -471,6 +475,49 @@ class RoomPermissionsTestCase(RoomBase):
         )
 
 
+class RoomStateTestCase(RoomBase):
+    """Tests /rooms/$room_id/state."""
+
+    user_id = "@sid1:red"
+
+    def test_get_state_cancellation(self) -> None:
+        """Test cancellation of a `/rooms/$room_id/state` request."""
+        room_id = self.helper.create_room_as(self.user_id)
+        channel = make_request_with_cancellation_test(
+            "test_state_cancellation",
+            self.reactor,
+            self.site,
+            "GET",
+            "/rooms/%s/state" % room_id,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertCountEqual(
+            [state_event["type"] for state_event in channel.json_body],
+            {
+                "m.room.create",
+                "m.room.power_levels",
+                "m.room.join_rules",
+                "m.room.member",
+                "m.room.history_visibility",
+            },
+        )
+
+    def test_get_state_event_cancellation(self) -> None:
+        """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+        room_id = self.helper.create_room_as(self.user_id)
+        channel = make_request_with_cancellation_test(
+            "test_state_cancellation",
+            self.reactor,
+            self.site,
+            "GET",
+            "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(channel.json_body, {"membership": "join"})
+
+
 class RoomsMemberListTestCase(RoomBase):
     """Tests /rooms/$room_id/members/list REST events."""
 
@@ -591,6 +638,62 @@ class RoomsMemberListTestCase(RoomBase):
         channel = self.make_request("GET", room_path)
         self.assertEqual(200, channel.code, msg=channel.result["body"])
 
+    def test_get_member_list_cancellation(self) -> None:
+        """Test cancellation of a `/rooms/$room_id/members` request."""
+        room_id = self.helper.create_room_as(self.user_id)
+        channel = make_request_with_cancellation_test(
+            "test_get_member_list_cancellation",
+            self.reactor,
+            self.site,
+            "GET",
+            "/rooms/%s/members" % room_id,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        self.assertLessEqual(
+            {
+                "content": {"membership": "join"},
+                "room_id": room_id,
+                "sender": self.user_id,
+                "state_key": self.user_id,
+                "type": "m.room.member",
+                "user_id": self.user_id,
+            }.items(),
+            channel.json_body["chunk"][0].items(),
+        )
+
+    def test_get_member_list_with_at_token_cancellation(self) -> None:
+        """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        # first sync to get an at token
+        channel = self.make_request("GET", "/sync")
+        self.assertEqual(200, channel.code)
+        sync_token = channel.json_body["next_batch"]
+
+        channel = make_request_with_cancellation_test(
+            "test_get_member_list_with_at_token_cancellation",
+            self.reactor,
+            self.site,
+            "GET",
+            "/rooms/%s/members?at=%s" % (room_id, sync_token),
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        self.assertLessEqual(
+            {
+                "content": {"membership": "join"},
+                "room_id": room_id,
+                "sender": self.user_id,
+                "state_key": self.user_id,
+                "type": "m.room.member",
+                "user_id": self.user_id,
+            }.items(),
+            channel.json_body["chunk"][0].items(),
+        )
+
 
 class RoomsCreateTestCase(RoomBase):
     """Tests /rooms and /rooms/$room_id REST events."""
@@ -677,9 +780,11 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request("POST", "/createRoom", content)
         self.assertEqual(200, channel.code)
 
-    def test_spam_checker_may_join_room(self) -> None:
+    def test_spam_checker_may_join_room_deprecated(self) -> None:
         """Tests that the user_may_join_room spam checker callback is correctly bypassed
         when creating a new room.
+
+        In this test, we use the deprecated API in which callbacks return a bool.
         """
 
         async def user_may_join_room(
@@ -701,6 +806,32 @@ class RoomsCreateTestCase(RoomBase):
 
         self.assertEqual(join_mock.call_count, 0)
 
+    def test_spam_checker_may_join_room(self) -> None:
+        """Tests that the user_may_join_room spam checker callback is correctly bypassed
+        when creating a new room.
+
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+        """
+
+        async def user_may_join_room(
+            mxid: str,
+            room_id: str,
+            is_invite: bool,
+        ) -> Codes:
+            return Codes.CONSENT_NOT_GIVEN
+
+        join_mock = Mock(side_effect=user_may_join_room)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+        channel = self.make_request(
+            "POST",
+            "/createRoom",
+            {},
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        self.assertEqual(join_mock.call_count, 0)
+
 
 class RoomTopicTestCase(RoomBase):
     """Tests /rooms/$room_id/topic REST events."""
@@ -911,9 +1042,11 @@ class RoomJoinTestCase(RoomBase):
         self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
         self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
 
-    def test_spam_checker_may_join_room(self) -> None:
+    def test_spam_checker_may_join_room_deprecated(self) -> None:
         """Tests that the user_may_join_room spam checker callback is correctly called
         and blocks room joins when needed.
+
+        This test uses the deprecated API, in which callbacks return booleans.
         """
 
         # Register a dummy callback. Make it allow all room joins for now.
@@ -926,6 +1059,8 @@ class RoomJoinTestCase(RoomBase):
         ) -> bool:
             return return_value
 
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
         callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
         self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
 
@@ -968,6 +1103,67 @@ class RoomJoinTestCase(RoomBase):
         return_value = False
         self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
 
+    def test_spam_checker_may_join_room(self) -> None:
+        """Tests that the user_may_join_room spam checker callback is correctly called
+        and blocks room joins when needed.
+
+        This test uses the more recent API, in which callbacks return `NOT_SPAM` or `Codes`.
+        """
+
+        # Register a dummy callback. Make it allow all room joins for now.
+        return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM
+
+        async def user_may_join_room(
+            userid: str,
+            room_id: str,
+            is_invited: bool,
+        ) -> Union[Literal["NOT_SPAM"], Codes]:
+            return return_value
+
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
+        callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+        self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+        # Join a first room, without being invited to it.
+        self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room1,
+                False,
+            ),
+        )
+        self.assertEqual(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Join a second room, this time with an invite for it.
+        self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+        self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+        # Check that the callback was called with the right arguments.
+        expected_call_args = (
+            (
+                self.user2,
+                self.room2,
+                True,
+            ),
+        )
+        self.assertEqual(
+            callback_mock.call_args,
+            expected_call_args,
+            callback_mock.call_args,
+        )
+
+        # Now make the callback deny all room joins, and check that a join actually fails.
+        return_value = Codes.CONSENT_NOT_GIVEN
+        self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+
 
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
@@ -2845,9 +3041,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
 
-    def test_threepid_invite_spamcheck(self) -> None:
+    def test_threepid_invite_spamcheck_deprecated(self) -> None:
+        """
+        Test allowing/blocking threepid invites with a spam-check module.
+
+        In this test, we use the deprecated API in which callbacks return a bool.
+        """
         # Mock a few functions to prevent the test from failing due to failing to talk to
-        # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+        # a remote IS. We keep the mock for _make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
         make_invite_mock = Mock(return_value=make_awaitable(0))
         self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
@@ -2901,3 +3102,67 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         # Also check that it stopped before calling _make_and_store_3pid_invite.
         make_invite_mock.assert_called_once()
+
+    def test_threepid_invite_spamcheck(self) -> None:
+        """
+        Test allowing/blocking threepid invites with a spam-check module.
+
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+        # Mock a few functions to prevent the test from failing due to failing to talk to
+        # a remote IS. We keep the mock for _make_and_store_3pid_invite around so we
+        # can check its call_count later on during the test.
+        make_invite_mock = Mock(return_value=make_awaitable(0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+        self.hs.get_identity_handler().lookup_3pid = Mock(
+            return_value=make_awaitable(None),
+        )
+
+        # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+        # allow everything for now.
+        # `spec` argument is needed for this function mock to have `__qualname__`, which
+        # is needed for `Measure` metrics buried in SpamChecker.
+        mock = Mock(
+            return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+            spec=lambda *x: None,
+        )
+        self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+        # Send a 3PID invite into the room and check that it succeeded.
+        email_to_invite = "teresa@example.com"
+        channel = self.make_request(
+            method="POST",
+            path="/rooms/" + self.room_id + "/invite",
+            content={
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": email_to_invite,
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Check that the callback was called with the right params.
+        mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+        # Check that the call to send the invite was made.
+        make_invite_mock.assert_called_once()
+
+        # Now change the return value of the callback to deny any invite and test that
+        # we can't send the invite.
+        mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+        channel = self.make_request(
+            method="POST",
+            path="/rooms/" + self.room_id + "/invite",
+            content={
+                "id_server": "example.com",
+                "id_access_token": "sometoken",
+                "medium": "email",
+                "address": email_to_invite,
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 403)
+
+        # Also check that it stopped before calling _make_and_store_3pid_invite.
+        make_invite_mock.assert_called_once()
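
The hunks above migrate the room-join and 3PID-invite spam checks from the deprecated bool-returning callbacks to the newer contract in which callbacks return either `synapse.module_api.NOT_SPAM` or a `Codes` value. For context, here is a minimal sketch of what a module-side callback under that newer contract might look like. Only `NOT_SPAM`, `Codes` and the callback signature come from the tests above; the module class, its config knob and the blocking policy are illustrative assumptions, and `register_spam_checker_callbacks` is the usual ModuleApi registration hook rather than anything this diff touches.

from typing import Literal, Union

import synapse.module_api
from synapse.api.errors import Codes
from synapse.module_api import ModuleApi


class ExampleSpamChecker:
    """Sketch of a module using the newer spam-check return values."""

    def __init__(self, config: dict, api: ModuleApi):
        # Hypothetical config knob: refuse joins from one denylisted homeserver.
        self._blocked_server: str = config.get("blocked_server", "spam.example.com")
        api.register_spam_checker_callbacks(user_may_join_room=self.user_may_join_room)

    async def user_may_join_room(
        self, user_id: str, room_id: str, is_invited: bool
    ) -> Union[Literal["NOT_SPAM"], Codes]:
        # Returning a Codes value blocks the join (the tests above use
        # Codes.CONSENT_NOT_GIVEN); returning NOT_SPAM allows it.
        if user_id.endswith(":" + self._blocked_server):
            return Codes.FORBIDDEN
        return synapse.module_api.NOT_SPAM
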
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 07e29788e5..e07ae78fc4 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -96,7 +96,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
@@ -112,7 +114,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -132,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -145,7 +147,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -156,7 +160,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+            return_value=make_awaitable(None)
+        )
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=make_awaitable(None)
         )
@@ -170,7 +176,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -185,7 +191,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -202,7 +208,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
         """
-        self._rlsn._auth.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8370a27195..78b83d97b6 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 
 import itertools
-from typing import List
+from typing import (
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
@@ -22,13 +32,13 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.state.v2 import (
     _get_auth_chain_difference,
     lexicographical_topological_sort,
     resolve_events_with_store,
 )
-from synapse.types import EventID
+from synapse.types import EventID, StateMap
 
 from tests import unittest
 
@@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
 
 
 class FakeClock:
-    def sleep(self, msec):
+    def sleep(self, msec: float) -> "defer.Deferred[None]":
         return defer.succeed(None)
 
 
@@ -60,7 +70,14 @@ class FakeEvent:
     as domain.
     """
 
-    def __init__(self, id, sender, type, state_key, content):
+    def __init__(
+        self,
+        id: str,
+        sender: str,
+        type: str,
+        state_key: Optional[str],
+        content: Mapping[str, object],
+    ):
         self.node_id = id
         self.event_id = EventID(id, "example.com").to_string()
         self.sender = sender
@@ -69,12 +86,12 @@ class FakeEvent:
         self.content = content
         self.room_id = ROOM_ID
 
-    def to_event(self, auth_events, prev_events):
+    def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
         """Given the auth_events and prev_events, convert to a Frozen Event
 
         Args:
-            auth_events (list[str]): list of event_ids
-            prev_events (list[str]): list of event_ids
+            auth_events: list of event_ids
+            prev_events: list of event_ids
 
         Returns:
             FrozenEvent
@@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
 
 
 class StateTestCase(unittest.TestCase):
-    def test_ban_vs_pl(self):
+    def test_ban_vs_pl(self) -> None:
         events = [
             FakeEvent(
                 id="PA",
@@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_join_rule_evasion(self):
+    def test_join_rule_evasion(self) -> None:
         events = [
             FakeEvent(
                 id="JR",
@@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_offtopic_pl(self):
+    def test_offtopic_pl(self) -> None:
         events = [
             FakeEvent(
                 id="PA",
@@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic_basic(self):
+    def test_topic_basic(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic_reset(self):
+    def test_topic_reset(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic(self):
+    def test_topic(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_mainline_sort(self):
+    def test_mainline_sort(self) -> None:
         """Tests that the mainline ordering works correctly."""
 
         events = [
@@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def do_check(self, events, edges, expected_state_ids):
+    def do_check(
+        self,
+        events: List[FakeEvent],
+        edges: List[List[str]],
+        expected_state_ids: List[str],
+    ) -> None:
         """Take a list of events and edges and calculate the state of the
         graph at END, and asserts it matches `expected_state_ids`
 
         Args:
-            events (list[FakeEvent])
-            edges (list[list[str]]): A list of chains of event edges, e.g.
+            events
+            edges: A list of chains of event edges, e.g.
                 `[[A, B, C]]` are edges A->B and B->C.
-            expected_state_ids (list[str]): The expected state at END, (excluding
+            expected_state_ids: The expected state at END, (excluding
                 the keys that haven't changed since START).
         """
         # We want to sort the events into topological order for processing.
-        graph = {}
+        graph: Dict[str, Set[str]] = {}
 
-        # node_id -> FakeEvent
-        fake_event_map = {}
+        fake_event_map: Dict[str, FakeEvent] = {}
 
         for ev in itertools.chain(INITIAL_EVENTS, events):
             graph[ev.node_id] = set()
@@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
             for a, b in pairwise(edge_list):
                 graph[a].add(b)
 
-        # event_id -> FrozenEvent
-        event_map = {}
-        # node_id -> state
-        state_at_event = {}
+        event_map: Dict[str, EventBase] = {}
+        state_at_event: Dict[str, StateMap[str]] = {}
 
         # We copy the map as the sort consumes the graph
         graph_copy = {k: set(v) for k, v in graph.items()}
@@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
             if fake_event.state_key is not None:
                 state_after[(fake_event.type, fake_event.state_key)] = event_id
 
-            auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
+            # This type ignore is a bit sad. Things we have tried:
+            # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
+            #    EventBuilder. But this is Hard because the relevant attributes are
+            #    DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
+            # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
+            #    change this function to accept Union[Event, EventBase, EventBuilder].
+            #    This seems reasonable to me, but mypy isn't happy. I think that's
+            #    a mypy bug, see https://github.com/python/mypy/issues/5570
+            # Instead, resort to a type-ignore.
+            auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))  # type: ignore[arg-type]
 
             auth_events = []
             for key in auth_types:
@@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
 
 
 class LexicographicalTestCase(unittest.TestCase):
-    def test_simple(self):
-        graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
+    def test_simple(self) -> None:
+        graph: Dict[str, Set[str]] = {
+            "l": {"o"},
+            "m": {"n", "o"},
+            "n": {"o"},
+            "o": set(),
+            "p": {"o"},
+        }
 
         res = list(lexicographical_topological_sort(graph, key=lambda x: x))
 
@@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
 
 
 class SimpleParamStateTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # We build up a simple DAG.
 
         event_map = {}
@@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             ]
         }
 
-    def test_event_map_none(self):
+    def test_event_map_none(self) -> None:
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
@@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
     events.
     """
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         # Test getting the auth difference for a simple chain with a single
         # unpersisted event:
         #
@@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         self.assertEqual(difference, {c.event_id})
 
-    def test_multiple_unpersisted_chain(self):
+    def test_multiple_unpersisted_chain(self) -> None:
         # Test getting the auth difference for a simple chain with multiple
         # unpersisted events:
         #
@@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         self.assertEqual(difference, {d.event_id, c.event_id})
 
-    def test_unpersisted_events_different_sets(self):
+    def test_unpersisted_events_different_sets(self) -> None:
         # Test getting the auth difference for with multiple unpersisted events
         # in different branches:
         #
@@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         self.assertEqual(difference, {d.event_id, e.event_id})
 
 
-def pairwise(iterable):
+T = TypeVar("T")
+
+
+def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
     next(b, None)
@@ -829,24 +866,26 @@ def pairwise(iterable):
 
 @attr.s
 class TestStateResolutionStore:
-    event_map = attr.ib()
+    event_map: Dict[str, EventBase] = attr.ib()
 
-    def get_events(self, event_ids, allow_rejected=False):
+    def get_events(
+        self, event_ids: Collection[str], allow_rejected: bool = False
+    ) -> "defer.Deferred[Dict[str, EventBase]]":
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            allow_rejected (bool): If True return rejected events.
+            event_ids: The event_ids of the events to fetch
+            allow_rejected: If True return rejected events.
 
         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            Dict from event_id to event.
         """
 
         return defer.succeed(
             {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
         )
 
-    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
+    def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -880,7 +919,9 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, room_id, auth_sets):
+    def get_auth_chain_difference(
+        self, room_id: str, auth_sets: List[Set[str]]
+    ) -> "defer.Deferred[Set[str]]":
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
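
The newly annotated `pairwise` helper above is the classic `itertools.tee` recipe; its terse docstring is easiest to read alongside a concrete run. The sketch below just restates the helper outside the test module and shows how `do_check` uses it to turn each chain of node ids into graph edges; nothing here is new beyond the example input.

import itertools
from typing import Iterable, Tuple, TypeVar

T = TypeVar("T")


def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


# A chain of node ids becomes consecutive edges, which is how do_check() builds
# the event graph from an edge list such as ["START", "T1", "PA", ...].
assert list(pairwise(["START", "T1", "PA"])) == [("START", "T1"), ("T1", "PA")]
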
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index e2c506e5a4..229ecd84a6 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -15,10 +15,12 @@
 import unittest
 from typing import Optional
 
+from parameterized import parameterized
+
 from synapse import event_auth
 from synapse.api.constants import EventContentFields
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.events import EventBase, make_event_from_dict
 from synapse.types import JsonDict, get_domain_from_id
 
@@ -30,38 +32,39 @@ class EventAuthTestCase(unittest.TestCase):
         """
         creator = "@creator:example.com"
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
+            _create_event(RoomVersions.V9, creator),
+            _join_event(RoomVersions.V9, creator),
         ]
 
         # creator should be able to send state
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V9,
-            _random_state_event(creator),
+            _random_state_event(RoomVersions.V9, creator),
             auth_events,
         )
 
         # ... but a rejected join_rules event should cause it to be rejected
-        rejected_join_rules = _join_rules_event(creator, "public")
+        rejected_join_rules = _join_rules_event(
+            RoomVersions.V9,
+            creator,
+            "public",
+        )
         rejected_join_rules.rejected_reason = "stinky"
         auth_events.append(rejected_join_rules)
 
         self.assertRaises(
             AuthError,
             event_auth.check_auth_rules_for_event,
-            RoomVersions.V9,
-            _random_state_event(creator),
+            _random_state_event(RoomVersions.V9, creator),
             auth_events,
         )
 
         # ... even if there is *also* a good join rules
-        auth_events.append(_join_rules_event(creator, "public"))
+        auth_events.append(_join_rules_event(RoomVersions.V9, creator, "public"))
 
         self.assertRaises(
             AuthError,
             event_auth.check_auth_rules_for_event,
-            RoomVersions.V9,
-            _random_state_event(creator),
+            _random_state_event(RoomVersions.V9, creator),
             auth_events,
         )
 
@@ -73,15 +76,14 @@ class EventAuthTestCase(unittest.TestCase):
         creator = "@creator:example.com"
         joiner = "@joiner:example.com"
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
-            _join_event(joiner),
+            _create_event(RoomVersions.V1, creator),
+            _join_event(RoomVersions.V1, creator),
+            _join_event(RoomVersions.V1, joiner),
         ]
 
         # creator should be able to send state
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V1,
-            _random_state_event(creator),
+            _random_state_event(RoomVersions.V1, creator),
             auth_events,
         )
 
@@ -89,8 +91,7 @@ class EventAuthTestCase(unittest.TestCase):
         self.assertRaises(
             AuthError,
             event_auth.check_auth_rules_for_event,
-            RoomVersions.V1,
-            _random_state_event(joiner),
+            _random_state_event(RoomVersions.V1, joiner),
             auth_events,
         )
 
@@ -104,28 +105,28 @@ class EventAuthTestCase(unittest.TestCase):
         king = "@joiner2:example.com"
 
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
+            _create_event(RoomVersions.V1, creator),
+            _join_event(RoomVersions.V1, creator),
             _power_levels_event(
-                creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
+                RoomVersions.V1,
+                creator,
+                {"state_default": "30", "users": {pleb: "29", king: "30"}},
             ),
-            _join_event(pleb),
-            _join_event(king),
+            _join_event(RoomVersions.V1, pleb),
+            _join_event(RoomVersions.V1, king),
         ]
 
         # pleb should not be able to send state
         self.assertRaises(
             AuthError,
             event_auth.check_auth_rules_for_event,
-            RoomVersions.V1,
-            _random_state_event(pleb),
+            _random_state_event(RoomVersions.V1, pleb),
             auth_events,
         ),
 
         # king should be able to send state
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V1,
-            _random_state_event(king),
+            _random_state_event(RoomVersions.V1, king),
             auth_events,
         )
 
@@ -134,37 +135,33 @@ class EventAuthTestCase(unittest.TestCase):
         creator = "@creator:example.com"
         other = "@other:example.com"
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
+            _create_event(RoomVersions.V1, creator),
+            _join_event(RoomVersions.V1, creator),
         ]
 
         # creator should be able to send aliases
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V1,
-            _alias_event(creator),
+            _alias_event(RoomVersions.V1, creator),
             auth_events,
         )
 
         # Reject an event with no state key.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V1,
-                _alias_event(creator, state_key=""),
+                _alias_event(RoomVersions.V1, creator, state_key=""),
                 auth_events,
             )
 
         # If the domain of the sender does not match the state key, reject.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V1,
-                _alias_event(creator, state_key="test.com"),
+                _alias_event(RoomVersions.V1, creator, state_key="test.com"),
                 auth_events,
             )
 
         # Note that the member does *not* need to be in the room.
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V1,
-            _alias_event(other),
+            _alias_event(RoomVersions.V1, other),
             auth_events,
         )
 
@@ -173,38 +170,35 @@ class EventAuthTestCase(unittest.TestCase):
         creator = "@creator:example.com"
         other = "@other:example.com"
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
+            _create_event(RoomVersions.V6, creator),
+            _join_event(RoomVersions.V6, creator),
         ]
 
         # creator should be able to send aliases
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _alias_event(creator),
+            _alias_event(RoomVersions.V6, creator),
             auth_events,
         )
 
         # No particular checks are done on the state key.
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _alias_event(creator, state_key=""),
+            _alias_event(RoomVersions.V6, creator, state_key=""),
             auth_events,
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _alias_event(creator, state_key="test.com"),
+            _alias_event(RoomVersions.V6, creator, state_key="test.com"),
             auth_events,
         )
 
         # Per standard auth rules, the member must be in the room.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _alias_event(other),
+                _alias_event(RoomVersions.V6, other),
                 auth_events,
             )
 
-    def test_msc2209(self):
+    @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
+    def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
         """
         Notifications power levels get checked due to MSC2209.
         """
@@ -212,28 +206,26 @@ class EventAuthTestCase(unittest.TestCase):
         pleb = "@joiner:example.com"
 
         auth_events = [
-            _create_event(creator),
-            _join_event(creator),
+            _create_event(room_version, creator),
+            _join_event(room_version, creator),
             _power_levels_event(
-                creator, {"state_default": "30", "users": {pleb: "30"}}
+                room_version, creator, {"state_default": "30", "users": {pleb: "30"}}
             ),
-            _join_event(pleb),
+            _join_event(room_version, pleb),
         ]
 
-        # pleb should be able to modify the notifications power level.
-        event_auth.check_auth_rules_for_event(
-            RoomVersions.V1,
-            _power_levels_event(pleb, {"notifications": {"room": 100}}),
-            auth_events,
+        pl_event = _power_levels_event(
+            room_version, pleb, {"notifications": {"room": 100}}
         )
 
-        # But an MSC2209 room rejects this change.
-        with self.assertRaises(AuthError):
-            event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _power_levels_event(pleb, {"notifications": {"room": 100}}),
-                auth_events,
-            )
+        # In room version V1, pleb should be able to modify the notifications power level.
+        if allow_modification:
+            event_auth.check_auth_rules_for_event(pl_event, auth_events)
+
+        else:
+            # But an MSC2209 room rejects this change.
+            with self.assertRaises(AuthError):
+                event_auth.check_auth_rules_for_event(pl_event, auth_events)
 
     def test_join_rules_public(self):
         """
@@ -243,58 +235,60 @@ class EventAuthTestCase(unittest.TestCase):
         pleb = "@joiner:example.com"
 
         auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+            ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+            ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+            ("m.room.join_rules", ""): _join_rules_event(
+                RoomVersions.V6, creator, "public"
+            ),
         }
 
         # Check join.
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
         # A user cannot be force-joined to a room.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _member_event(pleb, "join", sender=creator),
+                _member_event(RoomVersions.V6, pleb, "join", sender=creator),
                 auth_events.values(),
             )
 
         # Banned should be rejected.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "ban"
+        )
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _join_event(pleb),
+                _join_event(RoomVersions.V6, pleb),
                 auth_events.values(),
             )
 
         # A user who left can re-join.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "leave"
+        )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
         # A user can send a join if they're in the room.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "join"
+        )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
         # A user can accept an invite.
         auth_events[("m.room.member", pleb)] = _member_event(
-            pleb, "invite", sender=creator
+            RoomVersions.V6, pleb, "invite", sender=creator
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
@@ -306,64 +300,88 @@ class EventAuthTestCase(unittest.TestCase):
         pleb = "@joiner:example.com"
 
         auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+            ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+            ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+            ("m.room.join_rules", ""): _join_rules_event(
+                RoomVersions.V6, creator, "invite"
+            ),
         }
 
         # A join without an invite is rejected.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _join_event(pleb),
+                _join_event(RoomVersions.V6, pleb),
                 auth_events.values(),
             )
 
         # A user cannot be force-joined to a room.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _member_event(pleb, "join", sender=creator),
+                _member_event(RoomVersions.V6, pleb, "join", sender=creator),
                 auth_events.values(),
             )
 
         # Banned should be rejected.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "ban"
+        )
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _join_event(pleb),
+                _join_event(RoomVersions.V6, pleb),
                 auth_events.values(),
             )
 
         # A user who left cannot re-join.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "leave"
+        )
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _join_event(pleb),
+                _join_event(RoomVersions.V6, pleb),
                 auth_events.values(),
             )
 
         # A user can send a join if they're in the room.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V6, pleb, "join"
+        )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
         # A user can accept an invite.
         auth_events[("m.room.member", pleb)] = _member_event(
-            pleb, "invite", sender=creator
+            RoomVersions.V6, pleb, "invite", sender=creator
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V6,
-            _join_event(pleb),
+            _join_event(RoomVersions.V6, pleb),
             auth_events.values(),
         )
 
-    def test_join_rules_msc3083_restricted(self):
+    def test_join_rules_restricted_old_room(self) -> None:
+        """Old room versions should reject joins to restricted rooms"""
+        creator = "@creator:example.com"
+        pleb = "@joiner:example.com"
+
+        auth_events = {
+            ("m.room.create", ""): _create_event(RoomVersions.V6, creator),
+            ("m.room.member", creator): _join_event(RoomVersions.V6, creator),
+            ("m.room.power_levels", ""): _power_levels_event(
+                RoomVersions.V6, creator, {"invite": 0}
+            ),
+            ("m.room.join_rules", ""): _join_rules_event(
+                RoomVersions.V6, creator, "restricted"
+            ),
+        }
+
+        with self.assertRaises(AuthError):
+            event_auth.check_auth_rules_for_event(
+                _join_event(RoomVersions.V6, pleb),
+                auth_events.values(),
+            )
+
+    def test_join_rules_msc3083_restricted(self) -> None:
         """
         Test joining a restricted room from MSC3083.
 
@@ -377,29 +395,25 @@ class EventAuthTestCase(unittest.TestCase):
         pleb = "@joiner:example.com"
 
         auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
-            ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+            ("m.room.create", ""): _create_event(RoomVersions.V8, creator),
+            ("m.room.member", creator): _join_event(RoomVersions.V8, creator),
+            ("m.room.power_levels", ""): _power_levels_event(
+                RoomVersions.V8, creator, {"invite": 0}
+            ),
+            ("m.room.join_rules", ""): _join_rules_event(
+                RoomVersions.V8, creator, "restricted"
+            ),
         }
 
-        # Older room versions don't understand this join rule
-        with self.assertRaises(AuthError):
-            event_auth.check_auth_rules_for_event(
-                RoomVersions.V6,
-                _join_event(pleb),
-                auth_events.values(),
-            )
-
         # A properly formatted join event should work.
         authorised_join_event = _join_event(
+            RoomVersions.V8,
             pleb,
             additional_content={
                 EventContentFields.AUTHORISING_USER: "@creator:example.com"
             },
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V8,
             authorised_join_event,
             auth_events.values(),
         )
@@ -408,14 +422,16 @@ class EventAuthTestCase(unittest.TestCase):
         # are done properly).
         pl_auth_events = auth_events.copy()
         pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
-            creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
+            RoomVersions.V8,
+            creator,
+            {"invite": 100, "users": {"@inviter:foo.test": 150}},
         )
         pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
-            "@inviter:foo.test"
+            RoomVersions.V8, "@inviter:foo.test"
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V8,
             _join_event(
+                RoomVersions.V8,
                 pleb,
                 additional_content={
                     EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
@@ -427,20 +443,21 @@ class EventAuthTestCase(unittest.TestCase):
         # A join which is missing an authorised server is rejected.
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V8,
-                _join_event(pleb),
+                _join_event(RoomVersions.V8, pleb),
                 auth_events.values(),
             )
 
         # A join authorised by a user who is not in the room is rejected.
         pl_auth_events = auth_events.copy()
         pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
-            creator, {"invite": 100, "users": {"@other:example.com": 150}}
+            RoomVersions.V8,
+            creator,
+            {"invite": 100, "users": {"@other:example.com": 150}},
         )
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V8,
                 _join_event(
+                    RoomVersions.V8,
                     pleb,
                     additional_content={
                         EventContentFields.AUTHORISING_USER: "@other:example.com"
@@ -453,8 +470,8 @@ class EventAuthTestCase(unittest.TestCase):
         # *would* be valid, but is sent by a different user.)
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V8,
                 _member_event(
+                    RoomVersions.V8,
                     pleb,
                     "join",
                     sender=creator,
@@ -466,39 +483,41 @@ class EventAuthTestCase(unittest.TestCase):
             )
 
         # Banned should be rejected.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V8, pleb, "ban"
+        )
         with self.assertRaises(AuthError):
             event_auth.check_auth_rules_for_event(
-                RoomVersions.V8,
                 authorised_join_event,
                 auth_events.values(),
             )
 
         # A user who left can re-join.
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V8, pleb, "leave"
+        )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V8,
             authorised_join_event,
             auth_events.values(),
         )
 
         # A user can send a join if they're in the room. (This doesn't need to
         # be authorised since the user is already joined.)
-        auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+        auth_events[("m.room.member", pleb)] = _member_event(
+            RoomVersions.V8, pleb, "join"
+        )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V8,
-            _join_event(pleb),
+            _join_event(RoomVersions.V8, pleb),
             auth_events.values(),
         )
 
         # A user can accept an invite. (This doesn't need to be authorised since
         # the user was invited.)
         auth_events[("m.room.member", pleb)] = _member_event(
-            pleb, "invite", sender=creator
+            RoomVersions.V8, pleb, "invite", sender=creator
         )
         event_auth.check_auth_rules_for_event(
-            RoomVersions.V8,
-            _join_event(pleb),
+            _join_event(RoomVersions.V8, pleb),
             auth_events.values(),
         )
 
@@ -508,20 +527,25 @@ class EventAuthTestCase(unittest.TestCase):
 TEST_ROOM_ID = "!test:room"
 
 
-def _create_event(user_id: str) -> EventBase:
+def _create_event(
+    room_version: RoomVersion,
+    user_id: str,
+) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
-            "event_id": _get_event_id(),
+            **_maybe_get_event_id_dict_for_room_version(room_version),
             "type": "m.room.create",
             "state_key": "",
             "sender": user_id,
             "content": {"creator": user_id},
-        }
+        },
+        room_version=room_version,
     )
 
 
 def _member_event(
+    room_version: RoomVersion,
     user_id: str,
     membership: str,
     sender: Optional[str] = None,
@@ -530,79 +554,102 @@ def _member_event(
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
-            "event_id": _get_event_id(),
+            **_maybe_get_event_id_dict_for_room_version(room_version),
             "type": "m.room.member",
             "sender": sender or user_id,
             "state_key": user_id,
             "content": {"membership": membership, **(additional_content or {})},
             "prev_events": [],
-        }
+        },
+        room_version=room_version,
     )
 
 
-def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
-    return _member_event(user_id, "join", additional_content=additional_content)
+def _join_event(
+    room_version: RoomVersion,
+    user_id: str,
+    additional_content: Optional[dict] = None,
+) -> EventBase:
+    return _member_event(
+        room_version,
+        user_id,
+        "join",
+        additional_content=additional_content,
+    )
 
 
-def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
+def _power_levels_event(
+    room_version: RoomVersion,
+    sender: str,
+    content: JsonDict,
+) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
-            "event_id": _get_event_id(),
+            **_maybe_get_event_id_dict_for_room_version(room_version),
             "type": "m.room.power_levels",
             "sender": sender,
             "state_key": "",
             "content": content,
-        }
+        },
+        room_version=room_version,
     )
 
 
-def _alias_event(sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
     data = {
         "room_id": TEST_ROOM_ID,
-        "event_id": _get_event_id(),
+        **_maybe_get_event_id_dict_for_room_version(room_version),
         "type": "m.room.aliases",
         "sender": sender,
         "state_key": get_domain_from_id(sender),
         "content": {"aliases": []},
     }
     data.update(**kwargs)
-    return make_event_from_dict(data)
+    return make_event_from_dict(data, room_version=room_version)
 
 
-def _random_state_event(sender: str) -> EventBase:
+def _random_state_event(room_version: RoomVersion, sender: str) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
-            "event_id": _get_event_id(),
+            **_maybe_get_event_id_dict_for_room_version(room_version),
             "type": "test.state",
             "sender": sender,
             "state_key": "",
             "content": {"membership": "join"},
-        }
+        },
+        room_version=room_version,
     )
 
 
-def _join_rules_event(sender: str, join_rule: str) -> EventBase:
+def _join_rules_event(
+    room_version: RoomVersion, sender: str, join_rule: str
+) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
-            "event_id": _get_event_id(),
+            **_maybe_get_event_id_dict_for_room_version(room_version),
             "type": "m.room.join_rules",
             "sender": sender,
             "state_key": "",
             "content": {
                 "join_rule": join_rule,
             },
-        }
+        },
+        room_version=room_version,
     )
 
 
 event_count = 0
 
 
-def _get_event_id() -> str:
+def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict:
+    """If this room version needs it, generate an event id"""
+    if room_version.event_format != EventFormatVersions.V1:
+        return {}
+
     global event_count
     c = event_count
     event_count += 1
-    return "!%i:example.com" % (c,)
+    return {"event_id": "!%i:example.com" % (c,)}
diff --git a/tests/test_server.py b/tests/test_server.py
index 0f1eb43cbc..847432f791 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -34,7 +34,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 from tests.server import (
     FakeSite,
     ThreadedMemoryReactorClock,
@@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
         return HTTPStatus.OK, b"ok"
 
 
-class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeJsonResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeJsonResource` cancellation."""
 
     def setUp(self):
@@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "GET", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "POST", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
@@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         )
 
 
-class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeHtmlResource` cancellation."""
 
     def setUp(self):
@@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "GET", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "POST", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
         )
diff --git a/tests/test_state.py b/tests/test_state.py
index 95f81bebae..b005dd8d0f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Collection, Dict, List, Optional
+from typing import Collection, Dict, List, Optional, cast
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -22,6 +22,8 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.state import StateHandler, StateResolutionHandler
+from synapse.util import Clock
+from synapse.util.macaroons import MacaroonGenerator
 
 from tests import unittest
 
@@ -190,13 +192,18 @@ class StateTestCase(unittest.TestCase):
                 "get_clock",
                 "get_state_resolution_handler",
                 "get_account_validity_handler",
+                "get_macaroon_generator",
                 "hostname",
             ]
         )
+        clock = cast(Clock, MockClock())
         hs.config = default_config("tesths", True)
         hs.get_datastores.return_value = Mock(main=self.dummy_store)
         hs.get_state_handler.return_value = None
-        hs.get_clock.return_value = MockClock()
+        hs.get_clock.return_value = clock
+        hs.get_macaroon_generator.return_value = MacaroonGenerator(
+            clock, "tesths", b"verysecret"
+        )
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
         hs.get_storage_controllers.return_value = storage_controllers
diff --git a/tests/test_types.py b/tests/test_types.py
index 0b10dae848..d8d82a517e 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -26,10 +26,21 @@ class UserIDTestCase(unittest.HomeserverTestCase):
         self.assertEqual("test", user.domain)
         self.assertEqual(True, self.hs.is_mine(user))
 
-    def test_pase_empty(self):
+    def test_parse_rejects_empty_id(self):
         with self.assertRaises(SynapseError):
             UserID.from_string("")
 
+    def test_parse_rejects_missing_sigil(self):
+        with self.assertRaises(SynapseError):
+            UserID.from_string("alice:example.com")
+
+    def test_parse_rejects_missing_separator(self):
+        with self.assertRaises(SynapseError):
+            UserID.from_string("@alice.example.com")
+
+    def test_validation_rejects_missing_domain(self):
+        self.assertFalse(UserID.is_valid("@alice:"))
+
     def test_build(self):
         user = UserID("5678efgh", "my.domain")
 
diff --git a/tests/unittest.py b/tests/unittest.py
index e7f255b4fa..c645dd3563 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase):
                         "is_guest": False,
                     }
 
-                async def get_user_by_req(request, allow_guest=False, rights="access"):
+                async def get_user_by_req(request, allow_guest=False):
                     assert self.helper.auth_user_id is not None
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
new file mode 100644
index 0000000000..32125f7bb7
--- /dev/null
+++ b/tests/util/test_macaroons.py
@@ -0,0 +1,146 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class MacaroonGeneratorTestCase(TestCase):
+    def setUp(self):
+        self.reactor, hs_clock = get_clock()
+        self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
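+        # A second generator with the same server name but a different secret
+        # key, used to check that tokens do not verify across keys.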
+        self.other_macaroon_generator = MacaroonGenerator(
+            hs_clock, "tesths", b"anothersecretkey"
+        )
+
+    def test_guest_access_token(self):
+        """Test the generation and verification of guest access tokens"""
+        token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
+        user_id = self.macaroon_generator.verify_guest_token(token)
+        self.assertEqual(user_id, "@user:tesths")
+
+        # Raises with another secret key
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.other_macaroon_generator.verify_guest_token(token)
+
+        # Check that an old access token without the guest caveat does not work
+        macaroon = self.macaroon_generator._generate_base_macaroon("access")
+        macaroon.add_first_party_caveat(f"user_id = {user_id}")
+        macaroon.add_first_party_caveat("nonce = 0123456789abcdef")
+        token = macaroon.serialize()
+
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_guest_token(token)
+
+    def test_delete_pusher_token(self):
+        """Test the generation and verification of delete_pusher tokens"""
+        token = self.macaroon_generator.generate_delete_pusher_token(
+            "@user:tesths", "m.mail", "john@example.com"
+        )
+        user_id = self.macaroon_generator.verify_delete_pusher_token(
+            token, "m.mail", "john@example.com"
+        )
+        self.assertEqual(user_id, "@user:tesths")
+
+        # Raises with another secret key
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.other_macaroon_generator.verify_delete_pusher_token(
+                token, "m.mail", "john@example.com"
+            )
+
+        # Raises when verifying for another pushkey
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_delete_pusher_token(
+                token, "m.mail", "other@example.com"
+            )
+
+        # Raises when verifying for another app_id
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_delete_pusher_token(
+                token, "somethingelse", "john@example.com"
+            )
+
+        # Check that an old token without the app_id and pushkey still works
+        macaroon = self.macaroon_generator._generate_base_macaroon("delete_pusher")
+        macaroon.add_first_party_caveat("user_id = @user:tesths")
+        token = macaroon.serialize()
+        user_id = self.macaroon_generator.verify_delete_pusher_token(
+            token, "m.mail", "john@example.com"
+        )
+        self.assertEqual(user_id, "@user:tesths")
+
+    def test_short_term_login_token(self):
+        """Test the generation and verification of short-term login tokens"""
+        token = self.macaroon_generator.generate_short_term_login_token(
+            user_id="@user:tesths",
+            auth_provider_id="oidc",
+            auth_provider_session_id="sid",
+            duration_in_ms=2 * 60 * 1000,
+        )
+
+        info = self.macaroon_generator.verify_short_term_login_token(token)
+        self.assertEqual(info.user_id, "@user:tesths")
+        self.assertEqual(info.auth_provider_id, "oidc")
+        self.assertEqual(info.auth_provider_session_id, "sid")
+
+        # Raises with another secret key
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.other_macaroon_generator.verify_short_term_login_token(token)
+
+        # Wait a minute; with a two-minute lifetime the token is still valid
+        self.reactor.pump([60])
+        # Shouldn't raise yet
+        self.macaroon_generator.verify_short_term_login_token(token)
+        # Wait another minute
+        self.reactor.pump([60])
+        # Should now raise, since the token has expired
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_short_term_login_token(token)
+
+    def test_oidc_session_token(self):
+        """Test the generation and verification of OIDC session cookies"""
+        state = "arandomstate"
+        session_data = OidcSessionData(
+            idp_id="oidc",
+            nonce="nonce",
+            client_redirect_url="https://example.com/",
+            ui_auth_session_id="",
+        )
+        token = self.macaroon_generator.generate_oidc_session_token(
+            state, session_data, duration_in_ms=2 * 60 * 1000
+        ).encode("utf-8")
+        info = self.macaroon_generator.verify_oidc_session_token(token, state)
+        self.assertEqual(session_data, info)
+
+        # Raises with another secret key
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.other_macaroon_generator.verify_oidc_session_token(token, state)
+
+        # Should raise with another state
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_oidc_session_token(token, "anotherstate")
+
+        # Wait a minute; with a two-minute lifetime the token is still valid
+        self.reactor.pump([60])
+        # Shouldn't raise yet
+        self.macaroon_generator.verify_oidc_session_token(token, state)
+        # Wait another minute
+        self.reactor.pump([60])
+        # Should now raise, since the token has expired
+        with self.assertRaises(MacaroonVerificationFailedException):
+            self.macaroon_generator.verify_oidc_session_token(token, state)