summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_third_party_rules.py28
1 files changed, 24 insertions, 4 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index b737625e33..0048bea54a 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -114,10 +114,10 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
-    def test_modify_event(self):
-        """Tests that the module can successfully tweak an event before it is persisted.
-        """
-        # first patch the event checker so that it will modify the event
+    def test_cannot_modify_event(self):
+        """cannot accidentally modify an event before it is persisted"""
+
+        # first patch the event checker so that it will try to modify the event
         async def check(ev: EventBase, state):
             ev.content = {"x": "y"}
             return True
@@ -132,6 +132,26 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             access_token=self.tok,
         )
         self.render(request)
+        self.assertEqual(channel.result["code"], b"500", channel.result)
+
+    def test_modify_event(self):
+        """The module can return a modified version of the event"""
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            d = ev.get_dict()
+            d["content"] = {"x": "y"}
+            return d
+
+        current_rules_module().check_event_allowed = check
+
+        # now send the event
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {"x": "x"},
+            access_token=self.tok,
+        )
+        self.render(request)
         self.assertEqual(channel.result["code"], b"200", channel.result)
         event_id = channel.json_body["event_id"]