diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b20d949689..cc417e2fbf 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -84,3 +84,6 @@ class ExperimentalConfig(Config):
# MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
+
+ # MSC3772: A push rule for mutual relations.
+ self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index a17b35a605..4c7278b5a1 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
{
"kind": "event_match",
"key": "content.body",
+ # Match the localpart of the requester's MXID.
"pattern_type": "user_localpart",
}
],
@@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"pattern": "invite",
"_cache_key": "_invite_member",
},
+ # Match the requester's MXID.
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
],
"actions": [
@@ -351,6 +353,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
],
},
{
+ "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
+ "conditions": [
+ {
+ "kind": "org.matrix.msc3772.relation_match",
+ "rel_type": "m.thread",
+ # Match the requester's MXID.
+ "sender_type": "user_id",
+ }
+ ],
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ },
+ {
"rule_id": "global/underride/.m.rule.message",
"conditions": [
{
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 4cc8a2ecca..1a8e7ef3dc 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -121,6 +122,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
+ # Whether to support MSC3772 is supported.
+ self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
+
async def _get_rules_for_event(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]:
@@ -192,6 +196,60 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
+ async def _get_mutual_relations(
+ self, event: EventBase, rules: Iterable[Dict[str, Any]]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+ Fetch event metadata for events which related to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event_id: The event ID which is targeted by relations.
+ rules: The push rules which will be processed for this event.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+
+ # If the experimental feature is not enabled, skip fetching relations.
+ if not self._relations_match_enabled:
+ return {}
+
+ # If the event does not have a relation, then cannot have any mutual
+ # relations.
+ relation = relation_from_event(event)
+ if not relation:
+ return {}
+
+ # Pre-filter to figure out which relation types are interesting.
+ rel_types = set()
+ for rule in rules:
+ # Skip disabled rules.
+ if "enabled" in rule and not rule["enabled"]:
+ continue
+
+ for condition in rule["conditions"]:
+ if condition["kind"] != "org.matrix.msc3772.relation_match":
+ continue
+
+ # rel_type is required.
+ rel_type = condition.get("rel_type")
+ if rel_type:
+ rel_types.add(rel_type)
+
+ # If no valid rules were found, no mutual relations.
+ if not rel_types:
+ return {}
+
+ # If any valid rules were found, fetch the mutual relations.
+ return await self.store.get_mutual_event_relations(
+ relation.parent_id, rel_types
+ )
+
@measure_func("action_for_event_by_user")
async def action_for_event_by_user(
self, event: EventBase, context: EventContext
@@ -216,8 +274,17 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
+ relations = await self._get_mutual_relations(
+ event, itertools.chain(*rules_by_user.values())
+ )
+
evaluator = PushRuleEvaluatorForEvent(
- event, len(room_members), sender_power_level, power_levels
+ event,
+ len(room_members),
+ sender_power_level,
+ power_levels,
+ relations,
+ self._relations_match_enabled,
)
# If the event is not a state event check if any users ignore the sender.
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 63b22d50ae..5117ef6854 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -48,6 +48,10 @@ def format_push_rules_for_user(
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
+ sender_type = c.pop("sender_type", None)
+ if sender_type == "user_id":
+ c["sender"] = user.to_string()
+
rulearray = rules["global"][template_name]
template_rule = _rule_to_template(r)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 54db6b5612..2e8a017add 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
from matrix_common.regex import glob_to_regex, to_word_pattern
@@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent:
room_member_count: int,
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
+ relations: Dict[str, Set[Tuple[str, str]]],
+ relations_match_enabled: bool,
):
self._event = event
self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
self._power_levels = power_levels
+ self._relations = relations
+ self._relations_match_enabled = relations_match_enabled
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
@@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent:
return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels
)
+ elif (
+ condition["kind"] == "org.matrix.msc3772.relation_match"
+ and self._relations_match_enabled
+ ):
+ return self._relation_match(condition, user_id)
else:
+ # XXX This looks incorrect -- we have reached an unknown condition
+ # kind and are unconditionally returning that it matches. Note
+ # that it seems possible to provide a condition to the /pushrules
+ # endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True
def _event_match(self, condition: dict, user_id: str) -> bool:
@@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent:
return bool(r.search(body))
+ def _relation_match(self, condition: dict, user_id: str) -> bool:
+ """
+ Check an "relation_match" push rule condition.
+
+ Args:
+ condition: The "event_match" push rule condition to match.
+ user_id: The user's MXID.
+
+ Returns:
+ True if the condition matches the event, False otherwise.
+ """
+ rel_type = condition.get("rel_type")
+ if not rel_type:
+ logger.warning("relation_match condition missing rel_type")
+ return False
+
+ sender_pattern = condition.get("sender")
+ if sender_pattern is None:
+ sender_type = condition.get("sender_type")
+ if sender_type == "user_id":
+ sender_pattern = user_id
+ type_pattern = condition.get("type")
+
+ # If any other relations matches, return True.
+ for sender, event_type in self._relations.get(rel_type, ()):
+ if sender_pattern and not _glob_matches(sender_pattern, sender):
+ continue
+ if type_pattern and not _glob_matches(type_pattern, event_type):
+ continue
+ # All values must have matched.
+ return True
+
+ # No relations matched.
+ return False
+
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0df8ff5395..17e35cf63e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1828,6 +1828,10 @@ class PersistEventsStore:
self.store.get_aggregation_groups_for_event.invalidate,
(relation.parent_id,),
)
+ txn.call_after(
+ self.store.get_mutual_event_relations_for_rel_type.invalidate,
+ (relation.parent_id,),
+ )
if relation.rel_type == RelationTypes.REPLACE:
txn.call_after(
@@ -2004,6 +2008,11 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
+ self.store._invalidate_cache_and_stream(
+ txn,
+ self.store.get_mutual_event_relations_for_rel_type,
+ (redacted_relates_to,),
+ )
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ad67901cc1..4adabc88cc 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,11 @@ def _is_experimental_rule_enabled(
and not experimental_config.msc3786_enabled
):
return False
+ if (
+ rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+ and not experimental_config.msc3772_enabled
+ ):
+ return False
return True
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fe8fded88b..3b1b2ce6cb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from collections import defaultdict
from typing import (
Collection,
Dict,
@@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ @cached(iterable=True)
+ async def get_mutual_event_relations_for_rel_type(
+ self, event_id: str, relation_type: str
+ ) -> Set[Tuple[str, str]]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_mutual_event_relations_for_rel_type",
+ list_name="relation_types",
+ )
+ async def get_mutual_event_relations(
+ self, event_id: str, relation_types: Collection[str]
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ """
+ Fetch event metadata for events which related to the same event as the given event.
+
+ If the given event has no relation information, returns an empty dictionary.
+
+ Args:
+ event_id: The event ID which is targeted by relations.
+ relation_types: The relation types to check for mutual relations.
+
+ Returns:
+ A dictionary of relation type to:
+ A set of tuples of:
+ The sender
+ The event type
+ """
+ rel_type_sql, rel_type_args = make_in_list_sql_clause(
+ self.database_engine, "relation_type", relation_types
+ )
+
+ sql = f"""
+ SELECT DISTINCT relation_type, sender, type FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND {rel_type_sql}
+ """
+
+ def _get_event_relations(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Set[Tuple[str, str]]]:
+ txn.execute(sql, [event_id] + rel_type_args)
+ result = defaultdict(set)
+ for rel_type, sender, type in txn.fetchall():
+ result[rel_type].add((sender, type))
+ return result
+
+ return await self.db_pool.runInteraction(
+ "get_event_relations", _get_event_relations
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
|