summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_third_party_rules.py219
1 files changed, 217 insertions, 2 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 9cca9edd30..bfc04785b7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,12 +15,12 @@ import threading
 from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, LoginType, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.rest import admin
-from synapse.rest.client import login, room
+from synapse.rest.client import account, login, profile, room
 from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
@@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         admin.register_servlets,
         login.register_servlets,
         room.register_servlets,
+        profile.register_servlets,
+        account.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
@@ -530,3 +532,216 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             },
             tok=self.tok,
         )
+
+    def test_on_profile_update(self):
+        """Tests that the on_profile_update module callback is correctly called on
+        profile updates.
+        """
+        displayname = "Foo"
+        avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
+
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+
+        # Change the display name.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/profile/%s/displayname" % self.user_id,
+            {"displayname": displayname},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called once for our user.
+        m.assert_called_once()
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Test that by_admin is False.
+        self.assertFalse(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertIsNone(profile_info.avatar_url)
+
+        # Change the avatar.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id,
+            {"avatar_url": avatar_url},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called once for our user.
+        self.assertEqual(m.call_count, 2)
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Test that by_admin is False.
+        self.assertFalse(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertEqual(profile_info.avatar_url, avatar_url)
+
+    def test_on_profile_update_admin(self):
+        """Tests that the on_profile_update module callback is correctly called on
+        profile updates triggered by a server admin.
+        """
+        displayname = "Foo"
+        avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
+
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Change a user's profile.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % self.user_id,
+            {"displayname": displayname, "avatar_url": avatar_url},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called twice (since we update the display name
+        # and avatar separately).
+        self.assertEqual(m.call_count, 2)
+
+        # Get the arguments for the last call and check it's about the right user.
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Check that by_admin is True.
+        self.assertTrue(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertEqual(profile_info.avatar_url, avatar_url)
+
+    def test_on_user_deactivation_status_changed(self):
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(None))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._on_user_deactivation_status_changed_callbacks.append(
+            deactivation_mock,
+        )
+        # Also register a mocked callback for profile updates, to check that the
+        # deactivation code calls it in a way that let modules know the user is being
+        # deactivated.
+        profile_mock = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(
+            profile_mock,
+        )
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+        tok = self.login("altan", "password")
+
+        # Deactivate that user.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/account/deactivate",
+            {
+                "auth": {
+                    "type": LoginType.PASSWORD,
+                    "password": "password",
+                    "identifier": {
+                        "type": "m.id.user",
+                        "user": user_id,
+                    },
+                },
+                "erase": True,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with a True
+        # deactivated flag and a False by_admin flag.
+        self.assertEqual(args[0], user_id)
+        self.assertTrue(args[1])
+        self.assertFalse(args[2])
+
+        # Check that the profile update callback was called twice (once for the display
+        # name and once for the avatar URL), and that the "deactivation" boolean is true.
+        self.assertEqual(profile_mock.call_count, 2)
+        args = profile_mock.call_args[0]
+        self.assertTrue(args[3])
+
+    def test_on_user_deactivation_status_changed_admin(self):
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation triggered by a server admin as
+        well as a reactivation.
+        """
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": True},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        m.assert_called_once()
+        args = m.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with True deactivated
+        # and by_admin flags.
+        self.assertEqual(args[0], user_id)
+        self.assertTrue(args[1])
+        self.assertTrue(args[2])
+
+        # Reactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": False, "password": "hackme"},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        self.assertEqual(m.call_count, 2)
+        args = m.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with a False
+        # deactivated flag and a True by_admin flag.
+        self.assertEqual(args[0], user_id)
+        self.assertFalse(args[1])
+        self.assertTrue(args[2])