diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3277a116e8..7245830b01 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
+
# patch the rules module with a Mock which will return False for some event
# types
async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -315,6 +317,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -465,7 +468,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
@@ -971,3 +974,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+ """Tests that the on_add_user_third_party_identifier and
+ on_remove_user_third_party_identifier module callbacks are called
+ just before associating and removing a 3PID to/from an account.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_add_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_add_user_third_party_identifier_callback_mock
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked add callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_add_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ # Now remove the 3PID from the user
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+ self,
+ ) -> None:
+ """Tests that the on_remove_user_third_party_identifier module callback is called
+ when a user is deactivated and their third-party ID associations are deleted.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Now deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "deactivated": True,
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
|