diff --git a/changelog.d/8468.misc b/changelog.d/8468.misc
new file mode 100644
index 0000000000..32ba991e64
--- /dev/null
+++ b/changelog.d/8468.misc
@@ -0,0 +1 @@
+Additional testing for `ThirdPartyEventRules`.
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 1ca77519d5..e38b8e67fb 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -61,12 +61,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
- state_events = {}
- for key, event_id in prev_state_ids.items():
- state_events[key] = await self.store.get_event(event_id, allow_none=True)
+ events = await self.store.get_events(prev_state_ids.values())
+ state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
- ret = await self.third_party_rules.check_event_allowed(event, state_events)
- return ret
+ # The module can modify the event slightly if it wants, but caution should be
+ # exercised, and it's likely to go very wrong if applied to events received over
+ # federation.
+
+ return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 7b322f526c..c12518c931 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,33 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import threading
+
+from mock import Mock
+
+from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
from tests import unittest
+thread_local = threading.local()
+
class ThirdPartyRulesTestModule:
- def __init__(self, config, *args, **kwargs):
- pass
+ def __init__(self, config, module_api):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
return True
- async def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
@staticmethod
def parse_config(config):
return config
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
+
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -46,15 +56,13 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule",
"config": {},
}
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
+ return config
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
@@ -67,6 +75,14 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
@@ -76,6 +92,16 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
@@ -84,3 +110,35 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_modify_event(self):
+ """Tests that the module can successfully tweak an event before it is persisted.
+ """
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
|