summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py29
1 files changed, 15 insertions, 14 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d1caf8a0f7..3846fbc5f0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -198,7 +198,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]]
+        self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -206,7 +206,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.
 
         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.
 
         Returns:
@@ -220,12 +220,6 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}
 
-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
         for rule, enabled in rules:
@@ -246,9 +240,7 @@ class BulkPushRuleEvaluator:
             return {}
 
         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)
 
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -281,9 +273,17 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = "main"
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -352,6 +352,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )