summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py207
-rw-r--r--synapse/push/clientformat.py16
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/push/push_tools.py28
4 files changed, 161 insertions, 92 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py

index eced182fd5..f27ba64d53 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -32,13 +29,14 @@ from typing import ( from prometheus_client import Counter from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes +from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership -from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -48,7 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - push_rules_invalidation_counter = Counter( "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" ) @@ -109,6 +106,9 @@ class BulkPushRuleEvaluator: self.store = hs.get_datastores().main self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() + self.should_calculate_push_rules = self.hs.config.push.enable_push + + self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled self.room_push_rule_cache_metrics = register_cache( "cache", @@ -117,9 +117,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -171,8 +168,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -183,15 +193,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -200,61 +221,82 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. + async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type + Mapping of relation type to flattened events. """ + related_events: Dict[str, Dict[str, str]] = {} + if self._related_event_match_enabled: + related_event_id = event.content.get("m.relates_to", {}).get("event_id") + relation_type = event.content.get("m.relates_to", {}).get("rel_type") + if related_event_id is not None and relation_type is not None: + related_event = await self.store.get_event( + related_event_id, allow_none=True + ) + if related_event is not None: + related_events[relation_type] = _flatten_dict(related_event) + + reply_event_id = ( + event.content.get("m.relates_to", {}) + .get("m.in_reply_to", {}) + .get("event_id") + ) - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} + # convert replies to pseudo relations + if reply_event_id is not None: + related_event = await self.store.get_event( + reply_event_id, allow_none=True + ) - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue + if related_event is not None: + related_events["m.in_reply_to"] = _flatten_dict(related_event) - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue + # indicate that this is from a fallback relation. + if relation_type == "m.thread" and event.content.get( + "m.relates_to", {} + ).get("is_falling_back", False): + related_events["m.in_reply_to"][ + "im.vector.is_falling_back" + ] = "" - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) + return related_events - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] + ) -> None: + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. + """ + if not self.should_calculate_push_rules: + return + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. - """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -274,25 +316,24 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id else: # Since the event has not yet been persisted we check whether # the parent is part of a thread. - thread_id = await self.store.get_thread_id(relation.parent_id) or "main" + thread_id = await self.store.get_thread_id(relation.parent_id) + + related_events = await self._related_events(event) # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. @@ -302,12 +343,14 @@ class BulkPushRuleEvaluator: notification_levels[user_id] = int(level) evaluator = PushRuleEvaluator( - _flatten_dict(event), + _flatten_dict(event, room_version=event.room_version), room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, + related_events, + self._related_event_match_enabled, + event.room_version.msc3931_push_features, + self.hs.config.experimental.msc1767_enabled, # MSC3931 flag ) users = rules_by_user.keys() @@ -383,6 +426,7 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], + room_version: Optional[RoomVersion] = None, prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -394,6 +438,31 @@ def _flatten_dict( if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif isinstance(value, Mapping): + # do not set `room_version` due to recursion considerations below _flatten_dict(value, prefix=(prefix + [key]), result=result) + # `room_version` should only ever be set when looking at the top level of an event + if ( + room_version is not None + and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features + and isinstance(d, EventBase) + ): + # Room supports extensible events: replace `content.body` with the plain text + # representation from `m.markup`, as per MSC1767. + markup = d.get("content").get("m.markup") + if room_version.identifier.startswith("org.matrix.msc1767."): + markup = d.get("content").get("org.matrix.msc1767.markup") + if markup is not None and isinstance(markup, list): + text = "" + for rep in markup: + if not isinstance(rep, dict): + # invalid markup - skip all processing + break + if rep.get("mimetype", "text/plain") == "text/plain": + rep_text = rep.get("body") + if rep_text is not None and isinstance(rep_text, str): + text = rep_text.lower() + break + result["content.body"] = text + return result diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 7095ae83f9..622a1e35c5 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py
@@ -44,6 +44,12 @@ def format_push_rules_for_user( rulearray.append(template_rule) + pattern_type = template_rule.pop("pattern_type", None) + if pattern_type == "user_id": + template_rule["pattern"] = user.to_string() + elif pattern_type == "user_localpart": + template_rule["pattern"] = user.localpart + template_rule["enabled"] = enabled if "conditions" not in template_rule: @@ -93,10 +99,14 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]: if len(rule.conditions) != 1: return None thecond = rule.conditions[0] - if "pattern" not in thecond: - return None + templaterule = {"actions": rule.actions} - templaterule["pattern"] = thecond["pattern"] + if "pattern" in thecond: + templaterule["pattern"] = thecond["pattern"] + elif "pattern_type" in thecond: + templaterule["pattern_type"] = thecond["pattern_type"] + else: + return None else: # This should not be reached unless this function is not kept in sync # with PRIORITY_CLASS_INVERSE_MAP. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c2575ba3d9..93b255ced5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -37,8 +37,8 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index edeba27a45..7ee07e4bee 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py
@@ -17,7 +17,6 @@ from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.util.async_helpers import concurrently_execute async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: @@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge = len(invites) - room_notifs = [] - - async def get_room_unread_count(room_id: str) -> None: - room_notifs.append( - await store.get_unread_event_push_actions_by_room_for_user( - room_id, - user_id, - ) - ) - - await concurrently_execute(get_room_unread_count, joins, 10) - - for notifs in room_notifs: - # Combine the counts from all the threads. - notify_count = notifs.main_timeline.notify_count + sum( - n.notify_count for n in notifs.threads.values() - ) + room_to_count = await store.get_unread_counts_by_room_for_user(user_id) + for room_id, notify_count in room_to_count.items(): + # room_to_count may include rooms which the user has left, + # ignore those. + if room_id not in joins: + continue if notify_count == 0: continue @@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - # return one badge count per conversation badge += 1 else: - # increment the badge count by the number of unread messages in the room + # Increase badge by number of notifications in room + # NOTE: this includes threaded and unthreaded notifications. badge += notify_count + return badge