summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/handlers/federation_event.py4
-rw-r--r--synapse/handlers/message.py14
-rw-r--r--synapse/handlers/room.py39
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py74
5 files changed, 79 insertions, 56 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 275a37a575..4fbc79a6cb 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1017,7 +1017,9 @@ class FederationHandler:
 
         context = EventContext.for_outlier(self._storage_controllers)
 
-        await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
+        await self._bulk_push_rule_evaluator.action_for_events_by_user(
+            [(event, context)]
+        )
         try:
             await self._federation_event_handler.persist_events_and_notify(
                 event.room_id, [(event, context)]
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 06e41b5cc0..7da6316a82 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -2171,8 +2171,8 @@ class FederationEventHandler:
                     min_depth,
                 )
             else:
-                await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                    event, context
+                await self._bulk_push_rule_evaluator.action_for_events_by_user(
+                    [(event, context)]
                 )
 
         try:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 15b828dd74..468900a07f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1433,17 +1433,9 @@ class EventCreationHandler:
             a room that has been un-partial stated.
         """
 
-        for event, context in events_and_context:
-            # Skip push notification actions for historical messages
-            # because we don't want to notify people about old history back in time.
-            # The historical messages also do not have the proper `context.current_state_ids`
-            # and `state_groups` because they have `prev_events` that aren't persisted yet
-            # (historical messages persisted in reverse-chronological order).
-            if not event.internal_metadata.is_historical():
-                with opentracing.start_active_span("calculate_push_actions"):
-                    await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                        event, context
-                    )
+        await self._bulk_push_rule_evaluator.action_for_events_by_user(
+            events_and_context
+        )
 
         try:
             # If we're a worker we need to hit out to the master.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 638f54051a..cc1e5c8f97 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1055,9 +1055,6 @@ class RoomCreationHandler:
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
         depth = 1
 
-        # the last event sent/persisted to the db
-        last_sent_event_id: Optional[str] = None
-
         # the most recently created event
         prev_event: List[str] = []
         # a map of event types, state keys -> event_ids. We collect these mappings this as events are
@@ -1102,26 +1099,6 @@ class RoomCreationHandler:
 
             return new_event, new_context
 
-        async def send(
-            event: EventBase,
-            context: synapse.events.snapshot.EventContext,
-            creator: Requester,
-        ) -> int:
-            nonlocal last_sent_event_id
-
-            ev = await self.event_creation_handler.handle_new_client_event(
-                requester=creator,
-                events_and_context=[(event, context)],
-                ratelimit=False,
-                ignore_shadow_ban=True,
-            )
-
-            last_sent_event_id = ev.event_id
-
-            # we know it was persisted, so must have a stream ordering
-            assert ev.internal_metadata.stream_ordering
-            return ev.internal_metadata.stream_ordering
-
         try:
             config = self._presets_dict[preset_config]
         except KeyError:
@@ -1135,10 +1112,14 @@ class RoomCreationHandler:
         )
 
         logger.debug("Sending %s in new room", EventTypes.Member)
-        await send(creation_event, creation_context, creator)
+        ev = await self.event_creation_handler.handle_new_client_event(
+            requester=creator,
+            events_and_context=[(creation_event, creation_context)],
+            ratelimit=False,
+            ignore_shadow_ban=True,
+        )
+        last_sent_event_id = ev.event_id
 
-        # Room create event must exist at this point
-        assert last_sent_event_id is not None
         member_event_id, _ = await self.room_member_handler.update_membership(
             creator,
             creator.user,
@@ -1157,6 +1138,7 @@ class RoomCreationHandler:
         depth += 1
         state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
 
+        events_to_send = []
         # We treat the power levels override specially as this needs to be one
         # of the first events that get sent into a room.
         pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
@@ -1165,7 +1147,7 @@ class RoomCreationHandler:
                 EventTypes.PowerLevels, pl_content, False
             )
             current_state_group = power_context._state_group
-            await send(power_event, power_context, creator)
+            events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
                 "users": {creator_id: 100},
@@ -1214,9 +1196,8 @@ class RoomCreationHandler:
                 False,
             )
             current_state_group = pl_context._state_group
-            await send(pl_event, pl_context, creator)
+            events_to_send.append((pl_event, pl_context))
 
-        events_to_send = []
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
             room_alias_event, room_alias_context = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index a75386f6a0..d7795a9080 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -165,8 +165,21 @@ class BulkPushRuleEvaluator:
         return rules_by_user
 
     async def _get_power_levels_and_sender_level(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: EventContext,
+        event_id_to_event: Mapping[str, EventBase],
     ) -> Tuple[dict, Optional[int]]:
+        """
+        Given an event and an event context, get the power level event relevant to the event
+        and the power level of the sender of the event.
+        Args:
+            event: event to check
+            context: context of event to check
+            event_id_to_event: a mapping of event_id to event for a set of events being
+            batch persisted. This is needed as the sought-after power level event may
+            be in this batch rather than the DB
+        """
         # There are no power levels and sender levels possible to get from outlier
         if event.internal_metadata.is_outlier():
             return {}, None
@@ -177,15 +190,26 @@ class BulkPushRuleEvaluator:
         )
         pl_event_id = prev_state_ids.get(POWER_KEY)
 
+        # fastpath: if there's a power level event, that's all we need, and
+        # not having a power level event is an extreme edge case
         if pl_event_id:
-            # fastpath: if there's a power level event, that's all we need, and
-            # not having a power level event is an extreme edge case
-            auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
+            # Get the power level event from the batch, or fall back to the database.
+            pl_event = event_id_to_event.get(pl_event_id)
+            if pl_event:
+                auth_events = {POWER_KEY: pl_event}
+            else:
+                auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
         else:
             auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=False
             )
             auth_events_dict = await self.store.get_events(auth_events_ids)
+            # Some needed auth events might be in the batch, combine them with those
+            # fetched from the database.
+            for auth_event_id in auth_events_ids:
+                auth_event = event_id_to_event.get(auth_event_id)
+                if auth_event:
+                    auth_events_dict[auth_event_id] = auth_event
             auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
 
         sender_level = get_user_power_level(event.sender, auth_events)
@@ -194,16 +218,38 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
-    @measure_func("action_for_event_by_user")
-    async def action_for_event_by_user(
-        self, event: EventBase, context: EventContext
+    async def action_for_events_by_user(
+        self, events_and_context: List[Tuple[EventBase, EventContext]]
     ) -> None:
-        """Given an event and context, evaluate the push rules, check if the message
-        should increment the unread count, and insert the results into the
-        event_push_actions_staging table.
+        """Given a list of events and their associated contexts, evaluate the push rules
+        for each event, check if the message should increment the unread count, and
+        insert the results into the event_push_actions_staging table.
         """
-        if not event.internal_metadata.is_notifiable():
-            # Push rules for events that aren't notifiable can't be processed by this
+        # For batched events the power level events may not have been persisted yet,
+        # so we pass in the batched events. Thus if the event cannot be found in the
+        # database we can check in the batch.
+        event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+        for event, context in events_and_context:
+            await self._action_for_event_by_user(event, context, event_id_to_event)
+
+    @measure_func("action_for_event_by_user")
+    async def _action_for_event_by_user(
+        self,
+        event: EventBase,
+        context: EventContext,
+        event_id_to_event: Mapping[str, EventBase],
+    ) -> None:
+
+        if (
+            not event.internal_metadata.is_notifiable()
+            or event.internal_metadata.is_historical()
+        ):
+            # Push rules for events that aren't notifiable can't be processed by this and
+            # we want to skip push notification actions for historical messages
+            # because we don't want to notify people about old history back in time.
+            # The historical messages also do not have the proper `context.current_state_ids`
+            # and `state_groups` because they have `prev_events` that aren't persisted yet
+            # (historical messages persisted in reverse-chronological order).
             return
 
         # Disable counting as unread unless the experimental configuration is
@@ -223,7 +269,9 @@ class BulkPushRuleEvaluator:
         (
             power_levels,
             sender_power_level,
-        ) = await self._get_power_levels_and_sender_level(event, context)
+        ) = await self._get_power_levels_and_sender_level(
+            event, context, event_id_to_event
+        )
 
         # Find the event's thread ID.
         relation = relation_from_event(event)