diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 9c17a42b65..aba62b5dc8 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.rest import admin
@@ -126,3 +128,89 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Ensure no actions are generated!
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
bulk_evaluator._action_for_event_by_user.assert_not_called()
+
+ @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ def test_mentions(self) -> None:
+ """Test the behavior of an event which includes invalid mentions."""
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+
+ sentinel = object()
+
+ def create_and_process(mentions: Any = sentinel) -> bool:
+ """Returns true iff the `mentions` trigger an event push action."""
+ content = {}
+ if mentions is not sentinel:
+ content[EventContentFields.MSC3952_MENTIONS] = mentions
+
+ # Create a new message event which should cause a notification.
+ event, context = self.get_success(
+ self.event_creation_handler.create_event(
+ self.requester,
+ {
+ "type": "test",
+ "room_id": self.room_id,
+ "content": content,
+ "sender": f"@bob:{self.hs.hostname}",
+ },
+ )
+ )
+
+ # Ensure no actions are generated!
+ self.get_success(
+ bulk_evaluator.action_for_events_by_user([(event, context)])
+ )
+
+ # If any actions are generated for this event, return true.
+ result = self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_select_list(
+ table="event_push_actions_staging",
+ keyvalues={"event_id": event.event_id},
+ retcols=("*",),
+ desc="get_event_push_actions_staging",
+ )
+ )
+ return len(result) > 0
+
+ # Not including the mentions field should not notify.
+ self.assertFalse(create_and_process())
+ # An empty mentions field should not notify.
+ self.assertFalse(create_and_process({}))
+
+ # Non-dict mentions should be ignored.
+ mentions: Any
+ for mentions in (None, True, False, 1, "foo", []):
+ self.assertFalse(create_and_process(mentions))
+
+ # A non-list should be ignored.
+ for mentions in (None, True, False, 1, "foo", {}):
+ self.assertFalse(create_and_process({"user_ids": mentions}))
+
+ # The Matrix ID appearing anywhere in the list should notify.
+ self.assertTrue(create_and_process({"user_ids": [self.alice]}))
+ self.assertTrue(create_and_process({"user_ids": ["@another:test", self.alice]}))
+
+ # Duplicate user IDs should notify.
+ self.assertTrue(create_and_process({"user_ids": [self.alice, self.alice]}))
+
+ # Invalid entries in the list are ignored.
+ self.assertFalse(create_and_process({"user_ids": [None, True, False, {}, []]}))
+ self.assertTrue(
+ create_and_process({"user_ids": [None, True, False, {}, [], self.alice]})
+ )
+
+ # Room mentions from those without power should not notify.
+ self.assertFalse(create_and_process({"room": True}))
+
+ # Room mentions from those with power should notify.
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ {"notifications": {"room": 0}},
+ self.token,
+ state_key="",
+ )
+ self.assertTrue(create_and_process({"room": True}))
+
+ # Invalid data should not notify.
+ for mentions in (None, False, 1, "foo", [], {}):
+ self.assertFalse(create_and_process({"room": mentions}))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1b87756b75..9d01c989d4 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Optional, Union, cast
+from typing import Dict, List, Optional, Set, Union, cast
import frozendict
@@ -39,7 +39,12 @@ from tests.test_utils.event_injection import create_event, inject_member_event
class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(
- self, content: JsonMapping, related_events: Optional[JsonDict] = None
+ self,
+ content: JsonMapping,
+ *,
+ user_mentions: Optional[Set[str]] = None,
+ room_mention: bool = False,
+ related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
@@ -57,13 +62,15 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluator(
_flatten_dict(event),
+ user_mentions or set(),
+ room_mention,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
{} if related_events is None else related_events,
- True,
- event.room_version.msc3931_push_features,
- True,
+ related_event_match_enabled=True,
+ room_version_feature_flags=event.room_version.msc3931_push_features,
+ msc3931_enabled=True,
)
def test_display_name(self) -> None:
@@ -90,6 +97,51 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+ def test_user_mentions(self) -> None:
+ """Check for user mentions."""
+ condition = {"kind": "org.matrix.msc3952.is_user_mention"}
+
+ # No mentions shouldn't match.
+ evaluator = self._get_evaluator({})
+ self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+ # An empty set shouldn't match
+ evaluator = self._get_evaluator({}, user_mentions=set())
+ self.assertFalse(evaluator.matches(condition, "@user:test", None))
+
+ # The Matrix ID appearing anywhere in the mentions list should match
+ evaluator = self._get_evaluator({}, user_mentions={"@user:test"})
+ self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+ evaluator = self._get_evaluator(
+ {}, user_mentions={"@another:test", "@user:test"}
+ )
+ self.assertTrue(evaluator.matches(condition, "@user:test", None))
+
+ # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+ # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
+ def test_room_mentions(self) -> None:
+ """Check for room mentions."""
+ condition = {"kind": "org.matrix.msc3952.is_room_mention"}
+
+ # No room mention shouldn't match.
+ evaluator = self._get_evaluator({})
+ self.assertFalse(evaluator.matches(condition, None, None))
+
+ # Room mention should match.
+ evaluator = self._get_evaluator({}, room_mention=True)
+ self.assertTrue(evaluator.matches(condition, None, None))
+
+ # A room mention and user mention is valid.
+ evaluator = self._get_evaluator(
+ {}, user_mentions={"@another:test"}, room_mention=True
+ )
+ self.assertTrue(evaluator.matches(condition, None, None))
+
+ # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
+ # since the BulkPushRuleEvaluator is what handles data sanitisation.
+
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -308,7 +360,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
},
}
},
- {
+ related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
@@ -408,7 +460,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
},
}
},
- {
+ related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
|