summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py88
-rw-r--r--tests/push/test_push_rule_evaluator.py66
2 files changed, 147 insertions, 7 deletions
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 9c17a42b65..aba62b5dc8 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
 from unittest.mock import patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 from synapse.rest import admin
@@ -126,3 +128,89 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         # Ensure no actions are generated!
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
         bulk_evaluator._action_for_event_by_user.assert_not_called()
+
+    @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+    def test_mentions(self) -> None:
+        """Test the behavior of an event which includes invalid mentions."""
+        bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+
+        sentinel = object()
+
+        def create_and_process(mentions: Any = sentinel) -> bool:
+            """Returns true iff the `mentions` trigger an event push action."""
+            content = {}
+            if mentions is not sentinel:
+                content[EventContentFields.MSC3952_MENTIONS] = mentions
+
+            # Create a new message event which should cause a notification.
+            event, context = self.get_success(
+                self.event_creation_handler.create_event(
+                    self.requester,
+                    {
+                        "type": "test",
+                        "room_id": self.room_id,
+                        "content": content,
+                        "sender": f"@bob:{self.hs.hostname}",
+                    },
+                )
+            )
+
+            # Ensure no actions are generated!
+            self.get_success(
+                bulk_evaluator.action_for_events_by_user([(event, context)])
+            )
+
+            # If any actions are generated for this event, return true.
+            result = self.get_success(
+                self.hs.get_datastores().main.db_pool.simple_select_list(
+                    table="event_push_actions_staging",
+                    keyvalues={"event_id": event.event_id},
+                    retcols=("*",),
+                    desc="get_event_push_actions_staging",
+                )
+            )
+            return len(result) > 0
+
+        # Not including the mentions field should not notify.
+        self.assertFalse(create_and_process())
+        # An empty mentions field should not notify.
+        self.assertFalse(create_and_process({}))
+
+        # Non-dict mentions should be ignored.
+        mentions: Any
+        for mentions in (None, True, False, 1, "foo", []):
+            self.assertFalse(create_and_process(mentions))
+
+        # A non-list should be ignored.
+        for mentions in (None, True, False, 1, "foo", {}):
+            self.assertFalse(create_and_process({"user_ids": mentions}))
+
+        # The Matrix ID appearing anywhere in the list should notify.
+        self.assertTrue(create_and_process({"user_ids": [self.alice]}))
+        self.assertTrue(create_and_process({"user_ids": ["@another:test", self.alice]}))
+
+        # Duplicate user IDs should notify.
+        self.assertTrue(create_and_process({"user_ids": [self.alice, self.alice]}))
+
+        # Invalid entries in the list are ignored.
+        self.assertFalse(create_and_process({"user_ids": [None, True, False, {}, []]}))
+        self.assertTrue(
+            create_and_process({"user_ids": [None, True, False, {}, [], self.alice]})
+        )
+
+        # Room mentions from those without power should not notify.
+        self.assertFalse(create_and_process({"room": True}))
+
+        # Room mentions from those with power should notify.
+        self.helper.send_state(
+            self.room_id,
+            "m.room.power_levels",
+            {"notifications": {"room": 0}},
+            self.token,
+            state_key="",
+        )
+        self.assertTrue(create_and_process({"room": True}))
+
+        # Invalid data should not notify.
+        for mentions in (None, False, 1, "foo", [], {}):
+            self.assertFalse(create_and_process({"room": mentions}))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1b87756b75..9d01c989d4 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Union, cast
+from typing import Dict, List, Optional, Set, Union, cast
 
 import frozendict
 
@@ -39,7 +39,12 @@ from tests.test_utils.event_injection import create_event, inject_member_event
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
     def _get_evaluator(
-        self, content: JsonMapping, related_events: Optional[JsonDict] = None
+        self,
+        content: JsonMapping,
+        *,
+        user_mentions: Optional[Set[str]] = None,
+        room_mention: bool = False,
+        related_events: Optional[JsonDict] = None,
     ) -> PushRuleEvaluator:
         event = FrozenEvent(
             {
@@ -57,13 +62,15 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluator(
             _flatten_dict(event),
+            user_mentions or set(),
+            room_mention,
             room_member_count,
             sender_power_level,
             cast(Dict[str, int], power_levels.get("notifications", {})),
             {} if related_events is None else related_events,
-            True,
-            event.room_version.msc3931_push_features,
-            True,
+            related_event_match_enabled=True,
+            room_version_feature_flags=event.room_version.msc3931_push_features,
+            msc3931_enabled=True,
         )
 
     def test_display_name(self) -> None:
@@ -90,6 +97,51 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         # A display name with spaces should work fine.
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
+    def test_user_mentions(self) -> None:
+        """Check for user mentions."""
+        condition = {"kind": "org.matrix.msc3952.is_user_mention"}
+
+        # No mentions shouldn't match.
+        evaluator = self._get_evaluator({})
+        self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+        # An empty set shouldn't match
+        evaluator = self._get_evaluator({}, user_mentions=set())
+        self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+        # The Matrix ID appearing anywhere in the mentions list should match
+        evaluator = self._get_evaluator({}, user_mentions={"@user:test"})
+        self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+        evaluator = self._get_evaluator(
+            {}, user_mentions={"@another:test", "@user:test"}
+        )
+        self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+        # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+        # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
+    def test_room_mentions(self) -> None:
+        """Check for room mentions."""
+        condition = {"kind": "org.matrix.msc3952.is_room_mention"}
+
+        # No room mention shouldn't match.
+        evaluator = self._get_evaluator({})
+        self.assertFalse(evaluator.matches(condition, None, None))
+
+        # Room mention should match.
+        evaluator = self._get_evaluator({}, room_mention=True)
+        self.assertTrue(evaluator.matches(condition, None, None))
+
+        # A room mention and user mention is valid.
+        evaluator = self._get_evaluator(
+            {}, user_mentions={"@another:test"}, room_mention=True
+        )
+        self.assertTrue(evaluator.matches(condition, None, None))
+
+        # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+        # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
     def _assert_matches(
         self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
     ) -> None:
@@ -308,7 +360,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                     },
                 }
             },
-            {
+            related_events={
                 "m.in_reply_to": {
                     "event_id": "$parent_event_id",
                     "type": "m.room.message",
@@ -408,7 +460,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                     },
                 }
             },
-            {
+            related_events={
                 "m.in_reply_to": {
                     "event_id": "$parent_event_id",
                     "type": "m.room.message",