summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_push_rule_evaluator.py215
1 files changed, 214 insertions, 1 deletions
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index decf619466..fe7c145840 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator:
+    def _get_evaluator(
+        self, content: JsonDict, related_events=None
+    ) -> PushRuleEvaluator:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             room_member_count,
             sender_power_level,
             power_levels.get("notifications", {}),
+            {} if related_events is None else related_events,
+            True,
         )
 
     def test_display_name(self) -> None:
@@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             {"sound": "default", "highlight": True},
         )
 
+    def test_related_event_match(self):
+        evaluator = self._get_evaluator(
+            {
+                "m.relates_to": {
+                    "event_id": "$parent_event_id",
+                    "key": "😀",
+                    "rel_type": "m.annotation",
+                    "m.in_reply_to": {
+                        "event_id": "$parent_event_id",
+                    },
+                }
+            },
+            {
+                "m.in_reply_to": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+                "m.annotation": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+            },
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@user:test",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.annotation",
+                    "pattern": "@other_user:test",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.replace",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+
+    def test_related_event_match_with_fallback(self):
+        evaluator = self._get_evaluator(
+            {
+                "m.relates_to": {
+                    "event_id": "$parent_event_id",
+                    "key": "😀",
+                    "rel_type": "m.thread",
+                    "is_falling_back": True,
+                    "m.in_reply_to": {
+                        "event_id": "$parent_event_id",
+                    },
+                }
+            },
+            {
+                "m.in_reply_to": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                    "im.vector.is_falling_back": "",
+                },
+                "m.thread": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+            },
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                    "include_fallbacks": True,
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                    "include_fallbacks": False,
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+
+    def test_related_event_match_no_related_event(self):
+        evaluator = self._get_evaluator(
+            {"msgtype": "m.text", "body": "Message without related event"}
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+
 
 class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
     """Tests for the bulk push rule evaluator"""