summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_push_rule_evaluator.py147
1 files changed, 143 insertions, 4 deletions
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 516b65cc3c..6603447341 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -57,7 +57,7 @@ class FlattenDictTestCase(unittest.TestCase):
         )
 
     def test_non_string(self) -> None:
-        """Non-string items are dropped."""
+        """Booleans, ints, and nulls should be kept while other items are dropped."""
         input: Dict[str, Any] = {
             "woo": "woo",
             "foo": True,
@@ -66,7 +66,9 @@ class FlattenDictTestCase(unittest.TestCase):
             "fuzz": [],
             "boo": {},
         }
-        self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+        self.assertEqual(
+            {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input)
+        )
 
     def test_event(self) -> None:
         """Events can also be flattened."""
@@ -86,9 +88,9 @@ class FlattenDictTestCase(unittest.TestCase):
         )
         expected = {
             "content.msgtype": "m.text",
-            "content.body": "hello world!",
+            "content.body": "Hello world!",
             "content.format": "org.matrix.custom.html",
-            "content.formatted_body": "<h1>hello world!</h1>",
+            "content.formatted_body": "<h1>Hello world!</h1>",
             "room_id": "!test:test",
             "sender": "@alice:test",
             "type": "m.room.message",
@@ -166,6 +168,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             related_event_match_enabled=True,
             room_version_feature_flags=event.room_version.msc3931_push_features,
             msc3931_enabled=True,
+            msc3758_exact_event_match=True,
         )
 
     def test_display_name(self) -> None:
@@ -410,6 +413,142 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             "pattern should not match before a newline",
         )
 
+    def test_exact_event_match_string(self) -> None:
+        """Check that exact_event_match conditions work as expected for strings."""
+
+        # Test against a string value.
+        condition = {
+            "kind": "com.beeper.msc3758.exact_event_match",
+            "key": "content.value",
+            "value": "foobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"value": "foobaz"},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "FoobaZ"},
+            "values should match and be case-sensitive",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "test foobaz test"},
+            "values must exactly match",
+        )
+        value: Any
+        for value in (True, False, 1, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+        # it should work on frozendicts too
+        self._assert_matches(
+            condition,
+            frozendict.frozendict({"value": "foobaz"}),
+            "values should match on frozendicts",
+        )
+
+    def test_exact_event_match_boolean(self) -> None:
+        """Check that exact_event_match conditions work as expected for booleans."""
+
+        # Test against a True boolean value.
+        condition = {
+            "kind": "com.beeper.msc3758.exact_event_match",
+            "key": "content.value",
+            "value": True,
+        }
+        self._assert_matches(
+            condition,
+            {"value": True},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": False},
+            "incorrect values should not match",
+        )
+        for value in ("foobaz", 1, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+        # Test against a False boolean value.
+        condition = {
+            "kind": "com.beeper.msc3758.exact_event_match",
+            "key": "content.value",
+            "value": False,
+        }
+        self._assert_matches(
+            condition,
+            {"value": False},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": True},
+            "incorrect values should not match",
+        )
+        # Choose false-y values to ensure there's no type coercion.
+        for value in ("", 0, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+    def test_exact_event_match_null(self) -> None:
+        """Check that exact_event_match conditions work as expected for null."""
+
+        condition = {
+            "kind": "com.beeper.msc3758.exact_event_match",
+            "key": "content.value",
+            "value": None,
+        }
+        self._assert_matches(
+            condition,
+            {"value": None},
+            "exact value should match",
+        )
+        for value in ("foobaz", True, False, 1, 1.1, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+    def test_exact_event_match_integer(self) -> None:
+        """Check that exact_event_match conditions work as expected for integers."""
+
+        condition = {
+            "kind": "com.beeper.msc3758.exact_event_match",
+            "key": "content.value",
+            "value": 1,
+        }
+        self._assert_matches(
+            condition,
+            {"value": 1},
+            "exact value should match",
+        )
+        value: Any
+        for value in (1.1, -1, 0):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect values should not match",
+            )
+        for value in ("1", True, False, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
     def test_no_body(self) -> None:
         """Not having a body shouldn't break the evaluator."""
         evaluator = self._get_evaluator({})