diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 713dcf6950..75b7e126ca 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,31 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
+from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
-from .push_rule_evaluator import PushRuleEvaluatorForEvent
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-
push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
@@ -99,6 +106,8 @@ class BulkPushRuleEvaluator:
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
+ self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled
+
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
@@ -106,13 +115,10 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- # Whether to support MSC3772 is supported.
- self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
-
async def _get_rules_for_event(
self,
event: EventBase,
- ) -> Dict[str, List[Dict[str, Any]]]:
+ ) -> Dict[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
@@ -160,23 +166,51 @@ class BulkPushRuleEvaluator:
return rules_by_user
async def _get_power_levels_and_sender_level(
- self, event: EventBase, context: EventContext
- ) -> Tuple[dict, int]:
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
+ ) -> Tuple[dict, Optional[int]]:
+ """
+ Given an event and an event context, get the power level event relevant to the event
+ and the power level of the sender of the event.
+ Args:
+ event: event to check
+ context: context of event to check
+ event_id_to_event: a mapping of event_id to event for a set of events being
+ batch persisted. This is needed as the sought-after power level event may
+ be in this batch rather than the DB
+ """
+ # There are no power levels and sender levels possible to get from outlier
+ if event.internal_metadata.is_outlier():
+ return {}, None
+
event_types = auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
pl_event_id = prev_state_ids.get(POWER_KEY)
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
if pl_event_id:
- # fastpath: if there's a power level event, that's all we need, and
- # not having a power level event is an extreme edge case
- auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
+ # Get the power level event from the batch, or fall back to the database.
+ pl_event = event_id_to_event.get(pl_event_id)
+ if pl_event:
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
+ # Some needed auth events might be in the batch, combine them with those
+ # fetched from the database.
+ for auth_event_id in auth_events_ids:
+ auth_event = event_id_to_event.get(auth_event_id)
+ if auth_event:
+ auth_events_dict[auth_event_id] = auth_event
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -185,76 +219,91 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _get_mutual_relations(
- self, event: EventBase, rules: Iterable[Dict[str, Any]]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
-
- Args:
- event_id: The event ID which is targeted by relations.
- rules: The push rules which will be processed for this event.
+ async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+ """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+ Mapping of relation type to flattened events.
"""
+ related_events: Dict[str, Dict[str, str]] = {}
+ if self._related_event_match_enabled:
+ related_event_id = event.content.get("m.relates_to", {}).get("event_id")
+ relation_type = event.content.get("m.relates_to", {}).get("rel_type")
+ if related_event_id is not None and relation_type is not None:
+ related_event = await self.store.get_event(
+ related_event_id, allow_none=True
+ )
+ if related_event is not None:
+ related_events[relation_type] = _flatten_dict(related_event)
- # If the experimental feature is not enabled, skip fetching relations.
- if not self._relations_match_enabled:
- return {}
+ reply_event_id = (
+ event.content.get("m.relates_to", {})
+ .get("m.in_reply_to", {})
+ .get("event_id")
+ )
- # If the event does not have a relation, then cannot have any mutual
- # relations.
- relation = relation_from_event(event)
- if not relation:
- return {}
-
- # Pre-filter to figure out which relation types are interesting.
- rel_types = set()
- for rule in rules:
- # Skip disabled rules.
- if "enabled" in rule and not rule["enabled"]:
- continue
+ # convert replies to pseudo relations
+ if reply_event_id is not None:
+ related_event = await self.store.get_event(
+ reply_event_id, allow_none=True
+ )
- for condition in rule["conditions"]:
- if condition["kind"] != "org.matrix.msc3772.relation_match":
- continue
+ if related_event is not None:
+ related_events["m.in_reply_to"] = _flatten_dict(related_event)
- # rel_type is required.
- rel_type = condition.get("rel_type")
- if rel_type:
- rel_types.add(rel_type)
+ # indicate that this is from a fallback relation.
+ if relation_type == "m.thread" and event.content.get(
+ "m.relates_to", {}
+ ).get("is_falling_back", False):
+ related_events["m.in_reply_to"][
+ "im.vector.is_falling_back"
+ ] = ""
- # If no valid rules were found, no mutual relations.
- if not rel_types:
- return {}
+ return related_events
- # If any valid rules were found, fetch the mutual relations.
- return await self.store.get_mutual_event_relations(
- relation.parent_id, rel_types
- )
+ async def action_for_events_by_user(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
+ ) -> None:
+ """Given a list of events and their associated contexts, evaluate the push rules
+ for each event, check if the message should increment the unread count, and
+ insert the results into the event_push_actions_staging table.
+ """
+ # For batched events the power level events may not have been persisted yet,
+ # so we pass in the batched events. Thus if the event cannot be found in the
+ # database we can check in the batch.
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ for event, context in events_and_context:
+ await self._action_for_event_by_user(event, context, event_id_to_event)
@measure_func("action_for_event_by_user")
- async def action_for_event_by_user(
- self, event: EventBase, context: EventContext
+ async def _action_for_event_by_user(
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
) -> None:
- """Given an event and context, evaluate the push rules, check if the message
- should increment the unread count, and insert the results into the
- event_push_actions_staging table.
- """
- if event.internal_metadata.is_outlier():
- # This can happen due to out of band memberships
+
+ if (
+ not event.internal_metadata.is_notifiable()
+ or event.internal_metadata.is_historical()
+ ):
+ # Push rules for events that aren't notifiable can't be processed by this and
+ # we want to skip push notification actions for historical messages
+ # because we don't want to notify people about old history back in time.
+ # The historical messages also do not have the proper `context.current_state_ids`
+ # and `state_groups` because they have `prev_events` that aren't persisted yet
+ # (historical messages persisted in reverse-chronological order).
return
- count_as_unread = _should_count_as_unread(event, context)
+ # Disable counting as unread unless the experimental configuration is
+ # enabled, as it can cause additional (unwanted) rows to be added to the
+ # event_push_actions table.
+ count_as_unread = False
+ if self.hs.config.experimental.msc2654_enabled:
+ count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event)
- actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+ actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
@@ -263,19 +312,39 @@ class BulkPushRuleEvaluator:
(
power_levels,
sender_power_level,
- ) = await self._get_power_levels_and_sender_level(event, context)
-
- relations = await self._get_mutual_relations(
- event, itertools.chain(*rules_by_user.values())
+ ) = await self._get_power_levels_and_sender_level(
+ event, context, event_id_to_event
)
- evaluator = PushRuleEvaluatorForEvent(
- event,
+ # Find the event's thread ID.
+ relation = relation_from_event(event)
+ # If the event does not have a relation, then it cannot have a thread ID.
+ thread_id = MAIN_TIMELINE
+ if relation:
+ # Recursively attempt to find the thread this event relates to.
+ if relation.rel_type == RelationTypes.THREAD:
+ thread_id = relation.parent_id
+ else:
+ # Since the event has not yet been persisted we check whether
+ # the parent is part of a thread.
+ thread_id = await self.store.get_thread_id(relation.parent_id)
+
+ related_events = await self._related_events(event)
+
+ # It's possible that old room versions have non-integer power levels (floats or
+ # strings). Workaround this by explicitly converting to int.
+ notification_levels = power_levels.get("notifications", {})
+ if not event.room_version.msc3667_int_only_power_levels:
+ for user_id, level in notification_levels.items():
+ notification_levels[user_id] = int(level)
+
+ evaluator = PushRuleEvaluator(
+ _flatten_dict(event),
room_member_count,
sender_power_level,
- power_levels,
- relations,
- self._relations_match_enabled,
+ notification_levels,
+ related_events,
+ self._related_event_match_enabled,
)
users = rules_by_user.keys()
@@ -283,20 +352,10 @@ class BulkPushRuleEvaluator:
event.room_id, users
)
- # This is a check for the case where user joins a room without being
- # allowed to see history, and then the server receives a delayed event
- # from before the user joined, which they should not be pushed for
- uids_with_visibility = await filter_event_for_clients_with_state(
- self.store, users, event, context
- )
-
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
- if uid not in uids_with_visibility:
- continue
-
display_name = None
profile = profiles.get(uid)
if profile:
@@ -317,19 +376,30 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
- for rule in rules:
- if "enabled" in rule and not rule["enabled"]:
- continue
+ actions = evaluator.run(rules, uid, display_name)
+ if "notify" in actions:
+ # Push rules say we should notify the user of this event
+ actions_by_user[uid] = actions
- matches = evaluator.check_conditions(
- rule["conditions"], uid, display_name
- )
- if matches:
- actions = [x for x in rule["actions"] if x != "dont_notify"]
- if actions and "notify" in actions:
- # Push rules say we should notify the user of this event
- actions_by_user[uid] = actions
- break
+ # If there aren't any actions then we can skip the rest of the
+ # processing.
+ if not actions_by_user:
+ return
+
+ # This is a check for the case where user joins a room without being
+ # allowed to see history, and then the server receives a delayed event
+ # from before the user joined, which they should not be pushed for
+ #
+ # We do this *after* calculating the push actions as a) its unlikely
+ # that we'll filter anyone out and b) for large rooms its likely that
+ # most users will have push disabled and so the set of users to check is
+ # much smaller.
+ uids_with_visibility = await filter_event_for_clients_with_state(
+ self.store, actions_by_user.keys(), event, context
+ )
+
+ for user_id in set(actions_by_user).difference(uids_with_visibility):
+ actions_by_user.pop(user_id, None)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
@@ -338,6 +408,7 @@ class BulkPushRuleEvaluator:
event.event_id,
actions_by_user,
count_as_unread,
+ thread_id,
)
@@ -345,3 +416,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
+
+
+def _flatten_dict(
+ d: Union[EventBase, Mapping[str, Any]],
+ prefix: Optional[List[str]] = None,
+ result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+ if prefix is None:
+ prefix = []
+ if result is None:
+ result = {}
+ for key, value in d.items():
+ if isinstance(value, str):
+ result[".".join(prefix + [key])] = value.lower()
+ elif isinstance(value, Mapping):
+ _flatten_dict(value, prefix=(prefix + [key]), result=result)
+
+ return result
|