diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index a75386f6a0..d7795a9080 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -165,8 +165,21 @@ class BulkPushRuleEvaluator:
return rules_by_user
async def _get_power_levels_and_sender_level(
- self, event: EventBase, context: EventContext
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
) -> Tuple[dict, Optional[int]]:
+ """
+ Given an event and an event context, get the power level event relevant to the event
+ and the power level of the sender of the event.
+ Args:
+ event: event to check
+ context: context of event to check
+ event_id_to_event: a mapping of event_id to event for a set of events being
+ batch persisted. This is needed as the sought-after power level event may
+ be in this batch rather than the DB
+ """
# There are no power levels and sender levels possible to get from outlier
if event.internal_metadata.is_outlier():
return {}, None
@@ -177,15 +190,26 @@ class BulkPushRuleEvaluator:
)
pl_event_id = prev_state_ids.get(POWER_KEY)
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
if pl_event_id:
- # fastpath: if there's a power level event, that's all we need, and
- # not having a power level event is an extreme edge case
- auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
+ # Get the power level event from the batch, or fall back to the database.
+ pl_event = event_id_to_event.get(pl_event_id)
+ if pl_event:
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
+ # Some needed auth events might be in the batch, combine them with those
+ # fetched from the database.
+ for auth_event_id in auth_events_ids:
+ auth_event = event_id_to_event.get(auth_event_id)
+ if auth_event:
+ auth_events_dict[auth_event_id] = auth_event
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -194,16 +218,38 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- @measure_func("action_for_event_by_user")
- async def action_for_event_by_user(
- self, event: EventBase, context: EventContext
+ async def action_for_events_by_user(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
- """Given an event and context, evaluate the push rules, check if the message
- should increment the unread count, and insert the results into the
- event_push_actions_staging table.
+ """Given a list of events and their associated contexts, evaluate the push rules
+ for each event, check if the message should increment the unread count, and
+ insert the results into the event_push_actions_staging table.
"""
- if not event.internal_metadata.is_notifiable():
- # Push rules for events that aren't notifiable can't be processed by this
+ # For batched events the power level events may not have been persisted yet,
+ # so we pass in the batched events. Thus if the event cannot be found in the
+ # database we can check in the batch.
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ for event, context in events_and_context:
+ await self._action_for_event_by_user(event, context, event_id_to_event)
+
+ @measure_func("action_for_event_by_user")
+ async def _action_for_event_by_user(
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
+ ) -> None:
+
+ if (
+ not event.internal_metadata.is_notifiable()
+ or event.internal_metadata.is_historical()
+ ):
+ # Push rules for events that aren't notifiable can't be processed by this and
+ # we want to skip push notification actions for historical messages
+ # because we don't want to notify people about old history back in time.
+ # The historical messages also do not have the proper `context.current_state_ids`
+ # and `state_groups` because they have `prev_events` that aren't persisted yet
+ # (historical messages persisted in reverse-chronological order).
return
# Disable counting as unread unless the experimental configuration is
@@ -223,7 +269,9 @@ class BulkPushRuleEvaluator:
(
power_levels,
sender_power_level,
- ) = await self._get_power_levels_and_sender_level(event, context)
+ ) = await self._get_power_levels_and_sender_level(
+ event, context, event_id_to_event
+ )
# Find the event's thread ID.
relation = relation_from_event(event)
|