summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_third_party_rules.py84
1 files changed, 71 insertions, 13 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 7b322f526c..c12518c931 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,33 +12,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import threading
+
+from mock import Mock
+
+from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
 
 from tests import unittest
 
+thread_local = threading.local()
+
 
 class ThirdPartyRulesTestModule:
-    def __init__(self, config, *args, **kwargs):
-        pass
+    def __init__(self, config, module_api):
+        # keep a record of the "current" rules module, so that the test can patch
+        # it if desired.
+        thread_local.rules_module = self
 
     async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
     ):
         return True
 
-    async def check_event_allowed(self, event, context):
-        if event.type == "foo.bar.forbidden":
-            return False
-        else:
-            return True
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        return True
 
     @staticmethod
     def parse_config(config):
         return config
 
 
+def current_rules_module() -> ThirdPartyRulesTestModule:
+    return thread_local.rules_module
+
+
 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -46,15 +56,13 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
+    def default_config(self):
+        config = super().default_config()
         config["third_party_event_rules"] = {
             "module": __name__ + ".ThirdPartyRulesTestModule",
             "config": {},
         }
-
-        self.hs = self.setup_test_homeserver(config=config)
-        return self.hs
+        return config
 
     def prepare(self, reactor, clock, homeserver):
         # Create a user and room to play with during the tests
@@ -67,6 +75,14 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
         can be sent.
         """
+        # patch the rules module with a Mock which will return False for some event
+        # types
+        async def check(ev, state):
+            return ev.type != "foo.bar.forbidden"
+
+        callback = Mock(spec=[], side_effect=check)
+        current_rules_module().check_event_allowed = callback
+
         request, channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
@@ -76,6 +92,16 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+        callback.assert_called_once()
+
+        # there should be various state events in the state arg: do some basic checks
+        state_arg = callback.call_args[0][1]
+        for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+            self.assertIn(k, state_arg)
+            ev = state_arg[k]
+            self.assertEqual(ev.type, k[0])
+            self.assertEqual(ev.state_key, k[1])
+
         request, channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
@@ -84,3 +110,35 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
         self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_modify_event(self):
+        """Tests that the module can successfully tweak an event before it is persisted.
+        """
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            ev.content = {"x": "y"}
+            return True
+
+        current_rules_module().check_event_allowed = check
+
+        # now send the event
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {"x": "x"},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        event_id = channel.json_body["event_id"]
+
+        # ... and check that it got modified
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "y")