summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12740.feature1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/push/baserules.py14
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py71
-rw-r--r--synapse/push/clientformat.py4
-rw-r--r--synapse/push/push_rule_evaluator.py50
-rw-r--r--synapse/storage/databases/main/events.py9
-rw-r--r--synapse/storage/databases/main/push_rule.py5
-rw-r--r--synapse/storage/databases/main/relations.py52
-rw-r--r--tests/push/test_push_rule_evaluator.py84
10 files changed, 287 insertions, 6 deletions
diff --git a/changelog.d/12740.feature b/changelog.d/12740.feature
new file mode 100644
index 0000000000..e674c31ae8
--- /dev/null
+++ b/changelog.d/12740.feature
@@ -0,0 +1 @@
+Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b20d949689..cc417e2fbf 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -84,3 +84,6 @@ class ExperimentalConfig(Config):
 
         # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
         self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
+
+        # MSC3772: A push rule for mutual relations.
+        self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index a17b35a605..4c7278b5a1 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
             {
                 "kind": "event_match",
                 "key": "content.body",
+                # Match the localpart of the requester's MXID.
                 "pattern_type": "user_localpart",
             }
         ],
@@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "pattern": "invite",
                 "_cache_key": "_invite_member",
             },
+            # Match the requester's MXID.
             {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
         ],
         "actions": [
@@ -351,6 +353,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
         ],
     },
     {
+        "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
+        "conditions": [
+            {
+                "kind": "org.matrix.msc3772.relation_match",
+                "rel_type": "m.thread",
+                # Match the requester's MXID.
+                "sender_type": "user_id",
+            }
+        ],
+        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+    },
+    {
         "rule_id": "global/underride/.m.rule.message",
         "conditions": [
             {
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 4cc8a2ecca..1a8e7ef3dc 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
 from prometheus_client import Counter
@@ -121,6 +122,9 @@ class BulkPushRuleEvaluator:
             resizable=False,
         )
 
+        # Whether to support MSC3772 is supported.
+        self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
+
     async def _get_rules_for_event(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, List[Dict[str, Any]]]:
@@ -192,6 +196,60 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
+    async def _get_mutual_relations(
+        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+    ) -> Dict[str, Set[Tuple[str, str]]]:
+        """
+        Fetch event metadata for events which related to the same event as the given event.
+
+        If the given event has no relation information, returns an empty dictionary.
+
+        Args:
+            event_id: The event ID which is targeted by relations.
+            rules: The push rules which will be processed for this event.
+
+        Returns:
+            A dictionary of relation type to:
+                A set of tuples of:
+                    The sender
+                    The event type
+        """
+
+        # If the experimental feature is not enabled, skip fetching relations.
+        if not self._relations_match_enabled:
+            return {}
+
+        # If the event does not have a relation, then cannot have any mutual
+        # relations.
+        relation = relation_from_event(event)
+        if not relation:
+            return {}
+
+        # Pre-filter to figure out which relation types are interesting.
+        rel_types = set()
+        for rule in rules:
+            # Skip disabled rules.
+            if "enabled" in rule and not rule["enabled"]:
+                continue
+
+            for condition in rule["conditions"]:
+                if condition["kind"] != "org.matrix.msc3772.relation_match":
+                    continue
+
+                # rel_type is required.
+                rel_type = condition.get("rel_type")
+                if rel_type:
+                    rel_types.add(rel_type)
+
+        # If no valid rules were found, no mutual relations.
+        if not rel_types:
+            return {}
+
+        # If any valid rules were found, fetch the mutual relations.
+        return await self.store.get_mutual_event_relations(
+            relation.parent_id, rel_types
+        )
+
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
         self, event: EventBase, context: EventContext
@@ -216,8 +274,17 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
+        relations = await self._get_mutual_relations(
+            event, itertools.chain(*rules_by_user.values())
+        )
+
         evaluator = PushRuleEvaluatorForEvent(
-            event, len(room_members), sender_power_level, power_levels
+            event,
+            len(room_members),
+            sender_power_level,
+            power_levels,
+            relations,
+            self._relations_match_enabled,
         )
 
         # If the event is not a state event check if any users ignore the sender.
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 63b22d50ae..5117ef6854 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -48,6 +48,10 @@ def format_push_rules_for_user(
             elif pattern_type == "user_localpart":
                 c["pattern"] = user.localpart
 
+            sender_type = c.pop("sender_type", None)
+            if sender_type == "user_id":
+                c["sender"] = user.to_string()
+
         rulearray = rules["global"][template_name]
 
         template_rule = _rule_to_template(r)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 54db6b5612..2e8a017add 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,7 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
 
 from matrix_common.regex import glob_to_regex, to_word_pattern
 
@@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent:
         room_member_count: int,
         sender_power_level: int,
         power_levels: Dict[str, Union[int, Dict[str, int]]],
+        relations: Dict[str, Set[Tuple[str, str]]],
+        relations_match_enabled: bool,
     ):
         self._event = event
         self._room_member_count = room_member_count
         self._sender_power_level = sender_power_level
         self._power_levels = power_levels
+        self._relations = relations
+        self._relations_match_enabled = relations_match_enabled
 
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
@@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent:
             return _sender_notification_permission(
                 self._event, condition, self._sender_power_level, self._power_levels
             )
+        elif (
+            condition["kind"] == "org.matrix.msc3772.relation_match"
+            and self._relations_match_enabled
+        ):
+            return self._relation_match(condition, user_id)
         else:
+            # XXX This looks incorrect -- we have reached an unknown condition
+            #     kind and are unconditionally returning that it matches. Note
+            #     that it seems possible to provide a condition to the /pushrules
+            #     endpoint with an unknown kind, see _rule_tuple_from_request_object.
             return True
 
     def _event_match(self, condition: dict, user_id: str) -> bool:
@@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent:
 
         return bool(r.search(body))
 
+    def _relation_match(self, condition: dict, user_id: str) -> bool:
+        """
+        Check an "relation_match" push rule condition.
+
+        Args:
+            condition: The "event_match" push rule condition to match.
+            user_id: The user's MXID.
+
+        Returns:
+             True if the condition matches the event, False otherwise.
+        """
+        rel_type = condition.get("rel_type")
+        if not rel_type:
+            logger.warning("relation_match condition missing rel_type")
+            return False
+
+        sender_pattern = condition.get("sender")
+        if sender_pattern is None:
+            sender_type = condition.get("sender_type")
+            if sender_type == "user_id":
+                sender_pattern = user_id
+        type_pattern = condition.get("type")
+
+        # If any other relations matches, return True.
+        for sender, event_type in self._relations.get(rel_type, ()):
+            if sender_pattern and not _glob_matches(sender_pattern, sender):
+                continue
+            if type_pattern and not _glob_matches(type_pattern, event_type):
+                continue
+            # All values must have matched.
+            return True
+
+        # No relations matched.
+        return False
+
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0df8ff5395..17e35cf63e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1828,6 +1828,10 @@ class PersistEventsStore:
             self.store.get_aggregation_groups_for_event.invalidate,
             (relation.parent_id,),
         )
+        txn.call_after(
+            self.store.get_mutual_event_relations_for_rel_type.invalidate,
+            (relation.parent_id,),
+        )
 
         if relation.rel_type == RelationTypes.REPLACE:
             txn.call_after(
@@ -2004,6 +2008,11 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            self.store._invalidate_cache_and_stream(
+                txn,
+                self.store.get_mutual_event_relations_for_rel_type,
+                (redacted_relates_to,),
+            )
 
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ad67901cc1..4adabc88cc 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,11 @@ def _is_experimental_rule_enabled(
         and not experimental_config.msc3786_enabled
     ):
         return False
+    if (
+        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+        and not experimental_config.msc3772_enabled
+    ):
+        return False
     return True
 
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fe8fded88b..3b1b2ce6cb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from collections import defaultdict
 from typing import (
     Collection,
     Dict,
@@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    @cached(iterable=True)
+    async def get_mutual_event_relations_for_rel_type(
+        self, event_id: str, relation_type: str
+    ) -> Set[Tuple[str, str]]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_mutual_event_relations_for_rel_type",
+        list_name="relation_types",
+    )
+    async def get_mutual_event_relations(
+        self, event_id: str, relation_types: Collection[str]
+    ) -> Dict[str, Set[Tuple[str, str]]]:
+        """
+        Fetch event metadata for events which related to the same event as the given event.
+
+        If the given event has no relation information, returns an empty dictionary.
+
+        Args:
+            event_id: The event ID which is targeted by relations.
+            relation_types: The relation types to check for mutual relations.
+
+        Returns:
+            A dictionary of relation type to:
+                A set of tuples of:
+                    The sender
+                    The event type
+        """
+        rel_type_sql, rel_type_args = make_in_list_sql_clause(
+            self.database_engine, "relation_type", relation_types
+        )
+
+        sql = f"""
+            SELECT DISTINCT relation_type, sender, type FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE relates_to_id = ? AND {rel_type_sql}
+        """
+
+        def _get_event_relations(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Set[Tuple[str, str]]]:
+            txn.execute(sql, [event_id] + rel_type_args)
+            result = defaultdict(set)
+            for rel_type, sender, type in txn.fetchall():
+                result[rel_type].add((sender, type))
+            return result
+
+        return await self.db_pool.runInteraction(
+            "get_event_relations", _get_event_relations
+        )
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 5dba187076..9b623d0033 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Set, Tuple, Union
 
 import frozendict
 
@@ -26,7 +26,12 @@ from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
+    def _get_evaluator(
+        self,
+        content: JsonDict,
+        relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
+        relations_match_enabled: bool = False,
+    ) -> PushRuleEvaluatorForEvent:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         sender_power_level = 0
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluatorForEvent(
-            event, room_member_count, sender_power_level, power_levels
+            event,
+            room_member_count,
+            sender_power_level,
+            power_levels,
+            relations or set(),
+            relations_match_enabled,
         )
 
     def test_display_name(self) -> None:
@@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             push_rule_evaluator.tweaks_for_actions(actions),
             {"sound": "default", "highlight": True},
         )
+
+    def test_relation_match(self) -> None:
+        """Test the relation_match push rule kind."""
+
+        # Check if the experimental feature is disabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}
+        )
+        condition = {"kind": "relation_match"}
+        # Oddly, an unknown condition always matches.
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # A push rule evaluator with the experimental rule enabled.
+        evaluator = self._get_evaluator(
+            {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
+        )
+
+        # Check just relation type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and sender.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@user:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@other:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check relation type and event type.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "type": "m.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check just sender, this fails since rel_type is required.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "sender": "@user:test",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check sender glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "sender": "@*:test",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
+
+        # Check event type glob.
+        condition = {
+            "kind": "org.matrix.msc3772.relation_match",
+            "rel_type": "m.annotation",
+            "event_type": "*.reaction",
+        }
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))