summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_account.py4
-rw-r--r--tests/rest/client/test_auth.py19
-rw-r--r--tests/rest/client/test_capabilities.py1
-rw-r--r--tests/rest/client/test_consent.py1
-rw-r--r--tests/rest/client/test_directory.py1
-rw-r--r--tests/rest/client/test_ephemeral_message.py1
-rw-r--r--tests/rest/client/test_events.py3
-rw-r--r--tests/rest/client/test_filter.py1
-rw-r--r--tests/rest/client/test_keys.py141
-rw-r--r--tests/rest/client/test_login.py2
-rw-r--r--tests/rest/client/test_login_token_request.py1
-rw-r--r--tests/rest/client/test_presence.py1
-rw-r--r--tests/rest/client/test_profile.py3
-rw-r--r--tests/rest/client/test_register.py4
-rw-r--r--tests/rest/client/test_relations.py237
-rw-r--r--tests/rest/client/test_rendezvous.py1
-rw-r--r--tests/rest/client/test_rooms.py18
-rw-r--r--tests/rest/client/test_sync.py3
-rw-r--r--tests/rest/client/test_third_party_rules.py126
-rw-r--r--tests/rest/client/utils.py58
20 files changed, 352 insertions, 274 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index e2ee1a1766..2b05dffc7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
 
 
 class PasswordResetTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         account.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
 
 class DeactivateTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
 
 class WhoamiTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
 
 
 class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         account.register_servlets,
         login.register_servlets,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
 
@@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
         super().__init__(hs)
         self.recaptcha_attempts: List[Tuple[dict, str]] = []
 
+    def is_enabled(self) -> bool:
+        return True
+
     def check_auth(self, authdict: dict, clientip: str) -> Any:
         self.recaptcha_attempts.append((authdict, clientip))
         return succeed(True)
 
 
 class FallbackAuthTests(unittest.HomeserverTestCase):
-
     servlets = [
         auth.register_servlets,
         register.register_servlets,
@@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
 
         config["enable_registration_captcha"] = True
@@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
         channel = self.submit_logout_token(logout_token)
         self.assertEqual(channel.code, 200)
 
-        # Now try to exchange the login token
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
-        )
-        # It should have failed
-        self.assertEqual(channel.code, 403)
+        # Now try to exchange the login token, it should fail.
+        self.helper.login_via_token(login_token, 403)
 
     @override_config(
         {
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config
 
 
 class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         capabilities.register_servlets,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["form_secret"] = "123abc"
 
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         directory.register_servlets,
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest
 
 
 class EphemeralMessageTestCase(unittest.HomeserverTestCase):
-
     user_id = "@user:test"
 
     servlets = [
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["enable_registration_captcha"] = False
         config["enable_registration"] = True
@@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # register an account
         self.user_id = self.register_user("sid1", "pass")
         self.token = self.login(self.user_id, "pass")
@@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # register an account
         self.user_id = self.register_user("sid1", "pass")
         self.token = self.login(self.user_id, "pass")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 830762fd53..91678abf13 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
-
     user_id = "@apple:test"
     hijack_auth = True
     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@
 
 from http import HTTPStatus
 
+from signedjson.key import (
+    encode_verify_key_base64,
+    generate_signing_key,
+    get_verify_key,
+)
+from signedjson.sign import sign_json
+
 from synapse.api.errors import Codes
 from synapse.rest import admin
 from synapse.rest.client import keys, login
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
+from tests.unittest import override_config
 
 
 class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertIn(bob, channel.json_body["device_keys"])
+
+    def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+        # We only generate a master key to simplify the test.
+        master_signing_key = generate_signing_key(device_id)
+        master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+        return {
+            "master_key": sign_json(
+                {
+                    "user_id": user_id,
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + master_verify_key: master_verify_key},
+                },
+                user_id,
+                master_signing_key,
+            ),
+        }
+
+    def test_device_signing_with_uia(self) -> None:
+        """Device signing key upload requires UIA."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        content = self.make_device_keys(alice_id, device_id)
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # add UI auth
+        content["auth"] = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": alice_id},
+            "password": password,
+            "session": session,
+        }
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    @override_config({"ui_auth": {"session_timeout": "15m"}})
+    def test_device_signing_with_uia_session_timeout(self) -> None:
+        """Device signing key upload requires UIA buy passes with grace period."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        content = self.make_device_keys(alice_id, device_id)
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    @override_config(
+        {
+            "experimental_features": {"msc3967_enabled": True},
+            "ui_auth": {"session_timeout": "15s"},
+        }
+    )
+    def test_device_signing_with_msc3967(self) -> None:
+        """Device signing key follows MSC3967 behaviour when enabled."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        keys1 = self.make_device_keys(alice_id, device_id)
+
+        # Initial request should succeed as no existing keys are present.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys1,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+        keys2 = self.make_device_keys(alice_id, device_id)
+
+        # Subsequent request should require UIA as keys already exist even though session_timeout is set.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys2,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # add UI auth
+        keys2["auth"] = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": alice_id},
+            "password": password,
+            "session": session,
+        }
+
+        # Request should complete
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys2,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [
 
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
 
 class CASTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
     ]
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
 
 
 class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
         admin.register_servlets,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 67e16880e6..dcbb125a3b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase):
     servlets = [presence.register_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         self.presence_handler = Mock(spec=PresenceHandler)
         self.presence_handler.set_state.return_value = make_awaitable(None)
 
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest
 
 
 class ProfileTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
 
 class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 4c561f9525..b228dba861 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
 
 
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
         register.register_servlets,
@@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         register.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         register.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
-
     servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest
 from tests.server import FakeChannel
 from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
-from tests.unittest import override_config
 
 
 class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_edit(self) -> None:
         """Test that a simple edit works."""
-
+        orig_body = {"body": "Hi!", "msgtype": "m.text"}
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         edit_event_content = {
             "msgtype": "m.text",
@@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
+        self.assertEqual(channel.json_body["content"], orig_body)
         self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
 
         # Request the room messages.
@@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
 
         # Request the room context.
-        # /context should return the edited event.
+        # /context should return the event.
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/context/{self.parent_id}",
@@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
 
         # Request sync, but limit the timeline so it becomes limited (and includes
         # bundled aggregations).
@@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase):
             edit_event_content,
         )
 
-    @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
-    def test_edit_inhibit_replace(self) -> None:
-        """
-        If msc3925_inhibit_edit is enabled, then the original event should not be
-        replaced.
-        """
-
-        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
-        edit_event_content = {
-            "msgtype": "m.text",
-            "body": "foo",
-            "m.new_content": new_body,
-        }
-        channel = self._send_relation(
-            RelationTypes.REPLACE,
-            "m.room.message",
-            content=edit_event_content,
-        )
-        edit_event_id = channel.json_body["event_id"]
-
-        # /context should return the *original* event.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
-        self._assert_edit_bundle(
-            channel.json_body["event"], edit_event_id, edit_event_content
-        )
-
     def test_multi_edit(self) -> None:
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.
         """
-
+        orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"}
         self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
@@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
@@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_edit_edit(self) -> None:
         """Test that an edit cannot be edited."""
+        orig_body = {"body": "Hi!", "msgtype": "m.text"}
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
         edit_event_content = {
             "msgtype": "m.text",
@@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
+        self.assertEqual(channel.json_body["content"], orig_body)
 
         # The relations information should not include the edit to the edit.
         self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
 
-        # /context should return the event updated for the *first* edit
+        # /context should return the bundled edit for the *first* edit
         # (The edit to the edit should be ignored.)
         channel = self.make_request(
             "GET",
@@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
@@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_annotation(self) -> None:
-        """
-        Test that annotations get correctly bundled.
-        """
-        # Setup by sending a variety of relations.
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
-        def assert_annotations(bundled_aggregations: JsonDict) -> None:
-            self.assertEqual(
-                {
-                    "chunk": [
-                        {"type": "m.reaction", "key": "a", "count": 2},
-                        {"type": "m.reaction", "key": "b", "count": 1},
-                    ]
-                },
-                bundled_aggregations,
-            )
-
-        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
-
-    def test_annotation_to_annotation(self) -> None:
-        """Any relation to an annotation should be ignored."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        event_id = channel.json_body["event_id"]
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id
-        )
-
-        # Fetch the initial annotation event to see if it has bundled aggregations.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}",
-            access_token=self.user_token,
-        )
-        self.assertEquals(200, channel.code, channel.json_body)
-        # The first annotationt should not have any bundled aggregations.
-        self.assertNotIn("m.relations", channel.json_body["unsigned"])
-
     def test_reference(self) -> None:
         """
         Test that references get correctly bundled.
@@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
 
     def test_thread(self) -> None:
         """
@@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
@@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         thread_2 = channel.json_body["event_id"]
 
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2
         )
+        reference_event_id = channel.json_body["event_id"]
 
         def assert_thread(bundled_aggregations: JsonDict) -> None:
             self.assertEqual(2, bundled_aggregations.get("count"))
@@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             self.assert_dict(
                 {
                     "m.relations": {
-                        RelationTypes.ANNOTATION: {
-                            "chunk": [
-                                {"type": "m.reaction", "key": "a", "count": 1},
-                            ]
+                        RelationTypes.REFERENCE: {
+                            "chunk": [{"event_id": reference_event_id}]
                         },
                     }
                 },
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6)
 
     def test_nested_thread(self) -> None:
         """
@@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         thread_summary = relations_dict[RelationTypes.THREAD]
         self.assertIn("latest_event", thread_summary)
         latest_event_in_thread = thread_summary["latest_event"]
-        self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
         # The latest event in the thread should have the edit appear under the
         # bundled aggregations.
         self.assertDictContainsSubset(
@@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         thread_id = channel.json_body["event_id"]
 
-        # Annotate the thread.
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+        # Make a reference to the thread.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id
         )
+        reference_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
             "GET",
@@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         self.assertEqual(
             channel.json_body["unsigned"].get("m.relations"),
             {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
             },
         )
 
@@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         self.assertEqual(
             thread_message["unsigned"].get("m.relations"),
             {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
             },
         )
 
@@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         Note that the spec allows for a server to return additional fields beyond
         what is specified.
         """
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
+        reference_event_id = channel.json_body["event_id"]
 
         # Note that the sync filter does not include "unsigned" as a field.
         filter = urllib.parse.quote_plus(
@@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # Ensure there's bundled aggregations on it.
         self.assertIn("unsigned", parent_event)
-        self.assertIn("m.relations", parent_event["unsigned"])
+        self.assertEqual(
+            parent_event["unsigned"].get("m.relations"),
+            {
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
+            },
+        )
 
 
 class RelationIgnoredUserTestCase(BaseRelationsTestCase):
@@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
 
         return before_aggregations[relation_type], after_aggregations[relation_type]
 
-    def test_annotation(self) -> None:
-        """Annotations should ignore"""
-        # Send 2 from us, 2 from the to be ignored user.
-        allowed_event_ids = []
-        ignored_event_ids = []
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        allowed_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
-        allowed_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION,
-            "m.reaction",
-            key="a",
-            access_token=self.user2_token,
-        )
-        ignored_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION,
-            "m.reaction",
-            key="c",
-            access_token=self.user2_token,
-        )
-        ignored_event_ids.append(channel.json_body["event_id"])
-
-        before_aggregations, after_aggregations = self._test_ignored_user(
-            RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
-        )
-
-        self.assertCountEqual(
-            before_aggregations["chunk"],
-            [
-                {"type": "m.reaction", "key": "a", "count": 2},
-                {"type": "m.reaction", "key": "b", "count": 1},
-                {"type": "m.reaction", "key": "c", "count": 1},
-            ],
-        )
-
-        self.assertCountEqual(
-            after_aggregations["chunk"],
-            [
-                {"type": "m.reaction", "key": "a", "count": 1},
-                {"type": "m.reaction", "key": "b", "count": 1},
-            ],
-        )
-
     def test_reference(self) -> None:
-        """Annotations should ignore"""
+        """Aggregations should exclude reference relations from ignored users"""
         channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
         allowed_event_ids = [channel.json_body["event_id"]]
 
@@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
         )
 
     def test_thread(self) -> None:
-        """Annotations should ignore"""
+        """Aggregations should exclude thread releations from ignored users"""
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         allowed_event_ids = [channel.json_body["event_id"]]
 
@@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             for t in threads
         ]
 
-    def test_redact_relation_annotation(self) -> None:
-        """
-        Test that annotations of an event are properly handled after the
-        annotation is redacted.
-
-        The redacted relation should not be included in bundled aggregations or
-        the response to relations.
-        """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        to_redact_event_id = channel.json_body["event_id"]
-
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        unredacted_event_id = channel.json_body["event_id"]
-
-        # Both relations should exist.
-        event_ids = self._get_related_events()
-        relations = self._get_bundled_aggregations()
-        self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
-        self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
-        )
-
-        # Redact one of the reactions.
-        self._redact(to_redact_event_id)
-
-        # The unredacted relation should still exist.
-        event_ids = self._get_related_events()
-        relations = self._get_bundled_aggregations()
-        self.assertEquals(event_ids, [unredacted_event_id])
-        self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
-        )
-
     def test_redact_relation_thread(self) -> None:
         """
         Test that thread replies are properly handled after the thread reply redacted.
@@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         is redacted.
         """
         # Add a relation
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
         related_event_id = channel.json_body["event_id"]
 
         # The relations should exist.
         event_ids = self._get_related_events()
         relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 1)
-        self.assertIn(RelationTypes.ANNOTATION, relations)
+        self.assertIn(RelationTypes.REFERENCE, relations)
 
         # Redact the original event.
         self._redact(self.parent_id)
@@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [related_event_id])
         self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+            relations[RelationTypes.REFERENCE],
+            {"chunk": [{"event_id": related_event_id}]},
         )
 
     def test_redact_parent_thread(self) -> None:
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
 
 
 class RendezvousServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         rendezvous.register_servlets,
     ]
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index cfad182b2f..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase):
     servlets = [room.register_servlets, room.register_deprecated_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         self.hs = self.setup_test_homeserver(
             "red",
             federation_http_client=None,
@@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase):
     rmcreator_id = "@notme:red"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         self.helper.auth_user_id = self.rmcreator_id
         # create some rooms under the name rmcreator_id
         self.uncreated_rmid = "!aa:test"
@@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(33, channel.resource_usage.db_txn_count)
+        self.assertEqual(30, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(36, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
@@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase):
 
 
 class RoomJoinTestCase(RoomBase):
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # Register the user who does the searching
         self.user_id2 = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
@@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
 
 
 class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         self.url = b"/_matrix/client/r0/publicRooms"
 
         config = self.default_config()
@@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
 
 
 class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["allow_public_rooms_without_auth"] = True
         self.hs = self.setup_test_homeserver(config=config)
@@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
 
 class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase):
 
 
 class ContextTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 
 
 class ThreepidInviteTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         """
         Test allowing/blocking threepid invites with a spam-check module.
 
-        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+        """
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
-
     user_id = "@apple:test"
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
 
 
 class SyncTypingTests(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
 
 
 class ExcludeRoomTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3277a116e8..7245830b01 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
         can be sent.
         """
+
         # patch the rules module with a Mock which will return False for some event
         # types
         async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_modify_event(self) -> None:
         """The module can return a modified version of the event"""
+
         # first patch the event checker so that it will modify the event
         async def check(
             ev: EventBase, state: StateMap[EventBase]
@@ -315,6 +317,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_message_edit(self) -> None:
         """Ensure that the module doesn't cause issues with edited messages."""
+
         # first patch the event checker so that it will modify the event
         async def check(
             ev: EventBase, state: StateMap[EventBase]
@@ -465,7 +468,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         async def test_fn(
             event: EventBase, state_events: StateMap[EventBase]
         ) -> Tuple[bool, Optional[JsonDict]]:
-            if event.is_state and event.type == EventTypes.PowerLevels:
+            if event.is_state() and event.type == EventTypes.PowerLevels:
                 await api.create_and_send_event_into_room(
                     {
                         "room_id": event.room_id,
@@ -971,3 +974,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Check that the mock was called with the right parameters
         self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+    def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+        """Tests that the on_add_user_third_party_identifier and
+        on_remove_user_third_party_identifier module callbacks are called
+        just before associating and removing a 3PID to/from an account.
+        """
+        # Pretend to be a Synapse module and register both callbacks as mocks.
+        third_party_rules = self.hs.get_third_party_event_rules()
+        on_add_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        on_remove_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_add_user_third_party_identifier_callback_mock
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_remove_user_third_party_identifier_callback_mock
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Also register a normal user we can modify.
+        user_id = self.register_user("user", "password")
+
+        # Add a 3PID to the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [
+                    {
+                        "medium": "email",
+                        "address": "foo@example.com",
+                    },
+                ],
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked add callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_add_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+        # Now remove the 3PID from the user
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [],
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked remove callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+    def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+        self,
+    ) -> None:
+        """Tests that the on_remove_user_third_party_identifier module callback is called
+        when a user is deactivated and their third-party ID associations are deleted.
+        """
+        # Pretend to be a Synapse module and register both callbacks as mocks.
+        third_party_rules = self.hs.get_third_party_event_rules()
+        on_remove_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_remove_user_third_party_identifier_callback_mock
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Also register a normal user we can modify.
+        user_id = self.register_user("user", "password")
+
+        # Add a 3PID to the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [
+                    {
+                        "medium": "email",
+                        "address": "foo@example.com",
+                    },
+                ],
+            },
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Now deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "deactivated": True,
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked remove callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import Literal
 
+from twisted.test.proto_helpers import MemoryReactorClock
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
@@ -67,6 +68,7 @@ class RestHelper:
     """
 
     hs: HomeServer
+    reactor: MemoryReactorClock
     site: Site
     auth_user_id: Optional[str]
 
@@ -142,7 +144,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -216,7 +218,7 @@ class RestHelper:
             data["reason"] = reason
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -313,7 +315,7 @@ class RestHelper:
         data.update(extra_data or {})
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -394,7 +396,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -433,7 +435,7 @@ class RestHelper:
             path = path + f"?access_token={tok}"
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             path,
@@ -488,7 +490,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+        channel = make_request(self.reactor, self.site, method, path, content)
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
-            self.hs.get_reactor(),
-            FakeSite(resource, self.hs.get_reactor()),
+            self.reactor,
+            FakeSite(resource, self.reactor),
             "POST",
             path,
             content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
             expect_code: The return code to expect from attempting the whoami request
         """
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             "account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
 
-        Returns the result of the final token login.
+        Returns the result of the final token login and the fake authorization grant.
 
         Requires that "oidc_config" in the homeserver config be set appropriately
         (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
         assert m, channel.text_body
         login_token = m.group(1)
 
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
+        return self.login_via_token(login_token, expected_status), grant
+
+    def login_via_token(
+        self,
+        login_token: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Submit the matrix login token to the login API, which gives us our
+        matrix access token and device id.Log in (as a new user) via OIDC
+
+        Returns the result of the token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             "/login",
@@ -684,7 +704,7 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body, grant
+        return channel.json_body
 
     def auth_via_oidc(
         self,
@@ -805,7 +825,7 @@ class RestHelper:
         with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
-                self.hs.get_reactor(),
+                self.reactor,
                 self.site,
                 "GET",
                 callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             uri,
@@ -867,7 +887,7 @@ class RestHelper:
         location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
             + urllib.parse.urlencode({"session": ui_auth_session_id})
         )
         # hit the redirect url (which will issue a cookie and state)
-        channel = make_request(
-            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
-        )
+        channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
         # that should serve a confirmation page
         assert channel.code == HTTPStatus.OK, channel.text_body
         channel.extract_cookies(cookies)