diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index c0f93f898a..3b99513707 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -934,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+ """Tests that the on_add_user_third_party_identifier and
+ on_remove_user_third_party_identifier module callbacks are called
+ just before associating and removing a 3PID to/from an account.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_add_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_add_user_third_party_identifier_callback_mock
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked add callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_add_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ # Now remove the 3PID from the user
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+ self,
+ ) -> None:
+ """Tests that the on_remove_user_third_party_identifier module callback is called
+ when a user is deactivated and their third-party ID associations are deleted.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Now deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "deactivated": True,
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
|