summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_third_party_rules.py56
1 files changed, 35 insertions, 21 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 753ecc8d16..e5ba5a9706 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -22,7 +22,9 @@ from synapse.api.errors import SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
-from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
+    load_legacy_third_party_event_rules,
+)
 from synapse.rest import admin
 from synapse.rest.client import account, login, profile, room
 from synapse.server import HomeServer
@@ -146,7 +148,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             return ev.type != "foo.bar.forbidden", None
 
         callback = Mock(spec=[], side_effect=check)
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
             callback
         ]
 
@@ -202,7 +204,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         ) -> Tuple[bool, Optional[JsonDict]]:
             raise NastyHackException(429, "message")
 
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+            check
+        ]
 
         # Make a request
         channel = self.make_request(
@@ -229,7 +233,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             ev.content = {"x": "y"}
             return True, None
 
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+            check
+        ]
 
         # now send the event
         channel = self.make_request(
@@ -253,7 +259,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             d["content"] = {"x": "y"}
             return True, d
 
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+            check
+        ]
 
         # now send the event
         channel = self.make_request(
@@ -289,7 +297,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             }
             return True, d
 
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+            check
+        ]
 
         # Send an event, then edit it.
         channel = self.make_request(
@@ -440,7 +450,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
                 )
             return True, None
 
-        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn]
+        self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+            test_fn
+        ]
 
         # Sometimes the bug might not happen the first time the event type is added
         # to the state but might happen when an event updates the state of the room for
@@ -466,7 +478,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
     def test_on_new_event(self) -> None:
         """Test that the on_new_event callback is called on new events"""
         on_new_event = Mock(make_awaitable(None))
-        self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
+        self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
             on_new_event
         )
 
@@ -569,7 +581,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Register a mock callback.
         m = Mock(return_value=make_awaitable(None))
-        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+        self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
+            m
+        )
 
         # Change the display name.
         channel = self.make_request(
@@ -628,7 +642,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Register a mock callback.
         m = Mock(return_value=make_awaitable(None))
-        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+        self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
+            m
+        )
 
         # Register an admin user.
         self.register_user("admin", "password", admin=True)
@@ -667,7 +683,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mocked callback.
         deactivation_mock = Mock(return_value=make_awaitable(None))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(
             deactivation_mock,
         )
@@ -675,7 +691,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # deactivation code calls it in a way that let modules know the user is being
         # deactivated.
         profile_mock = Mock(return_value=make_awaitable(None))
-        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(
+        self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             profile_mock,
         )
 
@@ -725,7 +741,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mock callback.
         m = Mock(return_value=make_awaitable(None))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
 
         # Register an admin user.
@@ -779,7 +795,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mocked callback.
         deactivation_mock = Mock(return_value=make_awaitable(False))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
         )
@@ -825,7 +841,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mocked callback.
         deactivation_mock = Mock(return_value=make_awaitable(False))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
         )
@@ -864,7 +880,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mocked callback.
         shutdown_mock = Mock(return_value=make_awaitable(False))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_shutdown_room_callbacks.append(
             shutdown_mock,
         )
@@ -900,7 +916,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
         # Register a mocked callback.
         threepid_bind_mock = Mock(return_value=make_awaitable(None))
-        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
 
         # Register an admin user.
@@ -947,8 +963,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         on_remove_user_third_party_identifier_callback_mock = Mock(
             return_value=make_awaitable(None)
         )
-        third_party_rules = self.hs.get_third_party_event_rules()
-        third_party_rules.register_third_party_rules_callbacks(
+        self.hs.get_module_api().register_third_party_rules_callbacks(
             on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
             on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
         )
@@ -1009,8 +1024,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         on_remove_user_third_party_identifier_callback_mock = Mock(
             return_value=make_awaitable(None)
         )
-        third_party_rules = self.hs.get_third_party_event_rules()
-        third_party_rules.register_third_party_rules_callbacks(
+        self.hs.get_module_api().register_third_party_rules_callbacks(
             on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
         )