diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index decf619466..fe7c145840 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator:
+ def _get_evaluator(
+ self, content: JsonDict, related_events=None
+ ) -> PushRuleEvaluator:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count,
sender_power_level,
power_levels.get("notifications", {}),
+ {} if related_events is None else related_events,
+ True,
)
def test_display_name(self) -> None:
@@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
{"sound": "default", "highlight": True},
)
+ def test_related_event_match(self):
+ evaluator = self._get_evaluator(
+ {
+ "m.relates_to": {
+ "event_id": "$parent_event_id",
+ "key": "😀",
+ "rel_type": "m.annotation",
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ },
+ }
+ },
+ {
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ "m.annotation": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ },
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@user:test",
+ },
+ "@other_user:test",
+ "display_name",
+ )
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.annotation",
+ "pattern": "@other_user:test",
+ },
+ "@other_user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.replace",
+ },
+ "@other_user:test",
+ "display_name",
+ )
+ )
+
+ def test_related_event_match_with_fallback(self):
+ evaluator = self._get_evaluator(
+ {
+ "m.relates_to": {
+ "event_id": "$parent_event_id",
+ "key": "😀",
+ "rel_type": "m.thread",
+ "is_falling_back": True,
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ },
+ }
+ },
+ {
+ "m.in_reply_to": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ "im.vector.is_falling_back": "",
+ },
+ "m.thread": {
+ "event_id": "$parent_event_id",
+ "type": "m.room.message",
+ "sender": "@other_user:test",
+ "room_id": "!room:test",
+ "content.msgtype": "m.text",
+ "content.body": "Original message",
+ },
+ },
+ )
+ self.assertTrue(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ "include_fallbacks": True,
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ "include_fallbacks": False,
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+
+ def test_related_event_match_no_related_event(self):
+ evaluator = self._get_evaluator(
+ {"msgtype": "m.text", "body": "Message without related event"}
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ "pattern": "@other_user:test",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "key": "sender",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+ self.assertFalse(
+ evaluator.matches(
+ {
+ "kind": "im.nheko.msc3664.related_event_match",
+ "rel_type": "m.in_reply_to",
+ },
+ "@user:test",
+ "display_name",
+ )
+ )
+
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator"""
|