summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_third_party_rules.py132
1 files changed, 108 insertions, 24 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index c5e1c5458b..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
 from unittest.mock import Mock
 
 from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
 
 thread_local = threading.local()
 
 
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
     def __init__(self, config: Dict, module_api: ModuleApi):
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
         return config
 
 
-def current_rules_module() -> ThirdPartyRulesTestModule:
-    return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    def on_create_room(
+        self, requester: Requester, config: dict, is_requester_admin: bool
+    ):
+        return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        d = event.get_dict()
+        content = unfreeze(event.content)
+        content["foo"] = "bar"
+        d["content"] = content
+        return d
 
 
 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def default_config(self):
-        config = super().default_config()
-        config["third_party_event_rules"] = {
-            "module": __name__ + ".ThirdPartyRulesTestModule",
-            "config": {},
-        }
-        return config
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        load_legacy_third_party_event_rules(hs)
+
+        return hs
 
     def prepare(self, reactor, clock, homeserver):
         # Create a user and room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
         self.tok = self.login("kermit", "monkey")
 
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        # Some tests might prevent room creation on purpose.
+        try:
+            self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        except Exception:
+            pass
 
     def test_third_party_rules(self):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # patch the rules module with a Mock which will return False for some event
         # types
         async def check(ev, state):
-            return ev.type != "foo.bar.forbidden"
+            return ev.type != "foo.bar.forbidden", None
 
         callback = Mock(spec=[], side_effect=check)
-        current_rules_module().check_event_allowed = callback
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+            callback
+        ]
 
         channel = self.make_request(
             "PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # first patch the event checker so that it will try to modify the event
         async def check(ev: EventBase, state):
             ev.content = {"x": "y"}
-            return True
+            return True, None
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             {"x": "x"},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"500", channel.result)
+        # check_event_allowed has some error handling, so it shouldn't 500 just because a
+        # module did something bad.
+        self.assertEqual(channel.code, 200, channel.result)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "x")
 
     def test_modify_event(self):
         """The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         async def check(ev: EventBase, state):
             d = ev.get_dict()
             d["content"] = {"x": "y"}
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
                 "msgtype": "m.text",
                 "body": d["content"]["body"].upper(),
             }
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # Send an event, then edit it.
         channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(ev["content"]["body"], "EDITED BODY")
 
     def test_send_event(self):
-        """Tests that the module can send an event into a room via the module api"""
+        """Tests that a module can send an event into a room via the module api"""
         content = {
             "msgtype": "m.text",
             "body": "Hello!",
@@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             "sender": self.user_id,
         }
         event: EventBase = self.get_success(
-            current_rules_module().module_api.create_and_send_event_into_room(
-                event_dict
-            )
+            self.hs.get_module_api().create_and_send_event_into_room(event_dict)
         )
 
         self.assertEquals(event.sender, self.user_id)
         self.assertEquals(event.room_id, self.room_id)
         self.assertEquals(event.type, "m.room.message")
         self.assertEquals(event.content, content)
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyChangeEvents",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_check_event_allowed(self):
+        """Tests that the wrapper for legacy check_event_allowed callbacks works
+        correctly.
+        """
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+            {
+                "msgtype": "m.text",
+                "body": "Original body",
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        self.assertIn("foo", channel.json_body["content"].keys())
+        self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyDenyNewRooms",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_on_create_room(self):
+        """Tests that the wrapper for legacy on_create_room callbacks works
+        correctly.
+        """
+        self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)