summary refs log tree commit diff
path: root/tests/push/test_push_rule_evaluator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/push/test_push_rule_evaluator.py')
-rw-r--r--tests/push/test_push_rule_evaluator.py84
1 files changed, 81 insertions, 3 deletions
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 5dba187076..9b623d0033 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Set, Tuple, Union
 
 import frozendict
 
@@ -26,7 +26,12 @@ from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
+    def _get_evaluator(
+        self,
+        content: JsonDict,
+        relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
+        relations_match_enabled: bool = False,
+    ) -> PushRuleEvaluatorForEvent:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         sender_power_level = 0
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluatorForEvent(
-            event, room_member_count, sender_power_level, power_levels
+            event,
+            room_member_count,
+            sender_power_level,
+            power_levels,
+            relations or set(),
+            relations_match_enabled,
         )
 
     def test_display_name(self) -> None:
@@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             push_rule_evaluator.tweaks_for_actions(actions),
             {"sound": "default", "highlight": True},
         )
+
+    def test_relation_match(self) -> None:
+        """Test the relation_match push rule kind."""
+
+        # Check if the experimental feature is disabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}
+        )
+        condition = {"kind": "relation_match"}
+        # Oddly, an unknown condition always matches.
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # A push rule evaluator with the experimental rule enabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+        )
+
+        # Check just relation type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and sender.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@user:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@other:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and event type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "type": "m.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check just sender, this fails since rel_type is required.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "sender": "@user:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check sender glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@*:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check event type glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "event_type": "*.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))