diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index eced182fd5..f27ba64d53 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
- Iterable,
List,
Mapping,
Optional,
- Set,
Tuple,
Union,
)
@@ -32,13 +29,14 @@ from typing import (
from prometheus_client import Counter
from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
+from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
-from synapse.storage.state import StateFilter
-from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
+from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
+from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
@@ -48,7 +46,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-
push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
@@ -109,6 +106,9 @@ class BulkPushRuleEvaluator:
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
+ self.should_calculate_push_rules = self.hs.config.push.enable_push
+
+ self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled
self.room_push_rule_cache_metrics = register_cache(
"cache",
@@ -117,9 +117,6 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- # Whether to support MSC3772 is supported.
- self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
-
async def _get_rules_for_event(
self,
event: EventBase,
@@ -171,8 +168,21 @@ class BulkPushRuleEvaluator:
return rules_by_user
async def _get_power_levels_and_sender_level(
- self, event: EventBase, context: EventContext
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
) -> Tuple[dict, Optional[int]]:
+ """
+ Given an event and an event context, get the power level event relevant to the event
+ and the power level of the sender of the event.
+ Args:
+ event: event to check
+ context: context of event to check
+ event_id_to_event: a mapping of event_id to event for a set of events being
+ batch persisted. This is needed as the sought-after power level event may
+ be in this batch rather than the DB
+ """
# There are no power levels and sender levels possible to get from outlier
if event.internal_metadata.is_outlier():
return {}, None
@@ -183,15 +193,26 @@ class BulkPushRuleEvaluator:
)
pl_event_id = prev_state_ids.get(POWER_KEY)
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
if pl_event_id:
- # fastpath: if there's a power level event, that's all we need, and
- # not having a power level event is an extreme edge case
- auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
+ # Get the power level event from the batch, or fall back to the database.
+ pl_event = event_id_to_event.get(pl_event_id)
+ if pl_event:
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
+ # Some needed auth events might be in the batch; combine them with those
+ # fetched from the database.
+ for auth_event_id in auth_events_ids:
+ auth_event = event_id_to_event.get(auth_event_id)
+ if auth_event:
+ auth_events_dict[auth_event_id] = auth_event
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
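As background for the batch-aware lookup above: power level (and other auth) events in the same batch may not have been persisted yet, so the hunk prefers the in-batch copy and only falls back to the database. A minimal sketch of that pattern, assuming hypothetical names (`lookup_event`, `fetch_from_db`) rather than Synapse APIs:

from typing import Awaitable, Callable, Mapping, Optional

from synapse.events import EventBase


async def lookup_event(
    event_id: str,
    batch: Mapping[str, EventBase],
    fetch_from_db: Callable[[str], Awaitable[Optional[EventBase]]],
) -> Optional[EventBase]:
    # Prefer the copy from the batch being persisted; it may not be in the DB yet.
    event = batch.get(event_id)
    if event is not None:
        return event
    # Fall back to the database for events persisted earlier.
    return await fetch_from_db(event_id)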
@@ -200,61 +221,82 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _get_mutual_relations(
- self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
-
- Args:
- parent_id: The event ID which is targeted by relations.
- rules: The push rules which will be processed for this event.
+ async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+ """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+ Mapping of relation type to flattened events.
"""
+ related_events: Dict[str, Dict[str, str]] = {}
+ if self._related_event_match_enabled:
+ related_event_id = event.content.get("m.relates_to", {}).get("event_id")
+ relation_type = event.content.get("m.relates_to", {}).get("rel_type")
+ if related_event_id is not None and relation_type is not None:
+ related_event = await self.store.get_event(
+ related_event_id, allow_none=True
+ )
+ if related_event is not None:
+ related_events[relation_type] = _flatten_dict(related_event)
+
+ reply_event_id = (
+ event.content.get("m.relates_to", {})
+ .get("m.in_reply_to", {})
+ .get("event_id")
+ )
- # If the experimental feature is not enabled, skip fetching relations.
- if not self._relations_match_enabled:
- return {}
+ # convert replies to pseudo relations
+ if reply_event_id is not None:
+ related_event = await self.store.get_event(
+ reply_event_id, allow_none=True
+ )
- # Pre-filter to figure out which relation types are interesting.
- rel_types = set()
- for rule, enabled in rules:
- if not enabled:
- continue
+ if related_event is not None:
+ related_events["m.in_reply_to"] = _flatten_dict(related_event)
- for condition in rule.conditions:
- if condition["kind"] != "org.matrix.msc3772.relation_match":
- continue
+ # indicate that this is from a fallback relation.
+ if relation_type == "m.thread" and event.content.get(
+ "m.relates_to", {}
+ ).get("is_falling_back", False):
+ related_events["m.in_reply_to"][
+ "im.vector.is_falling_back"
+ ] = ""
- # rel_type is required.
- rel_type = condition.get("rel_type")
- if rel_type:
- rel_types.add(rel_type)
+ return related_events
- # If no valid rules were found, no mutual relations.
- if not rel_types:
- return {}
-
- # If any valid rules were found, fetch the mutual relations.
- return await self.store.get_mutual_event_relations(parent_id, rel_types)
+ async def action_for_events_by_user(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
+ ) -> None:
+ """Given a list of events and their associated contexts, evaluate the push rules
+ for each event, check if the message should increment the unread count, and
+ insert the results into the event_push_actions_staging table.
+ """
+ if not self.should_calculate_push_rules:
+ return
+ # For batched events the power level events may not have been persisted yet,
+ # so we pass in the batched events: if an event cannot be found in the
+ # database, we can then check the batch.
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ for event, context in events_and_context:
+ await self._action_for_event_by_user(event, context, event_id_to_event)
@measure_func("action_for_event_by_user")
- async def action_for_event_by_user(
- self, event: EventBase, context: EventContext
+ async def _action_for_event_by_user(
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
) -> None:
- """Given an event and context, evaluate the push rules, check if the message
- should increment the unread count, and insert the results into the
- event_push_actions_staging table.
- """
- if not event.internal_metadata.is_notifiable():
- # Push rules for events that aren't notifiable can't be processed by this
+
+ if (
+ not event.internal_metadata.is_notifiable()
+ or event.internal_metadata.is_historical()
+ ):
+ # Push rules for events that aren't notifiable can't be processed by this method,
+ # and we want to skip push notification actions for historical messages
+ # because we don't want to notify people about old history back in time.
+ # Historical messages also do not have the proper `context.current_state_ids`
+ # and `state_groups` because they have `prev_events` that aren't persisted yet
+ # (historical messages are persisted in reverse-chronological order).
return
# Disable counting as unread unless the experimental configuration is
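For concreteness, the rough shape of `m.relates_to` content that the new `_related_events` helper inspects, and the mapping it would produce; all identifiers and bodies below are made-up examples, not fixtures from the codebase:

# Illustrative input for _related_events (values are made up).
example_content = {
    "m.relates_to": {
        "rel_type": "m.thread",
        "event_id": "$thread_root",
        # Clients set this when the reply is only a thread fallback.
        "is_falling_back": True,
        "m.in_reply_to": {"event_id": "$latest_event_in_thread"},
    },
}

# Expected result shape: one flattened related event per relation type, with the
# im.vector.is_falling_back marker added to the reply when it is a thread fallback.
expected_related_events = {
    "m.thread": {"sender": "@alice:example.org", "content.body": "thread root body"},
    "m.in_reply_to": {
        "sender": "@bob:example.org",
        "content.body": "latest message in thread",
        "im.vector.is_falling_back": "",
    },
}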
@@ -274,25 +316,24 @@ class BulkPushRuleEvaluator:
(
power_levels,
sender_power_level,
- ) = await self._get_power_levels_and_sender_level(event, context)
+ ) = await self._get_power_levels_and_sender_level(
+ event, context, event_id_to_event
+ )
+ # Find the event's thread ID.
relation = relation_from_event(event)
- # If the event does not have a relation, then cannot have any mutual
- # relations or thread ID.
- relations = {}
+ # If the event does not have a relation, then it cannot have a thread ID.
thread_id = MAIN_TIMELINE
if relation:
- relations = await self._get_mutual_relations(
- relation.parent_id,
- itertools.chain(*(r.rules() for r in rules_by_user.values())),
- )
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
- thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
+ thread_id = await self.store.get_thread_id(relation.parent_id)
+
+ related_events = await self._related_events(event)
# It's possible that old room versions have non-integer power levels (floats or
# strings). Workaround this by explicitly converting to int.
@@ -302,12 +343,14 @@ class BulkPushRuleEvaluator:
notification_levels[user_id] = int(level)
evaluator = PushRuleEvaluator(
- _flatten_dict(event),
+ _flatten_dict(event, room_version=event.room_version),
room_member_count,
sender_power_level,
notification_levels,
- relations,
- self._relations_match_enabled,
+ related_events,
+ self._related_event_match_enabled,
+ event.room_version.msc3931_push_features,
+ self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
)
users = rules_by_user.keys()
@@ -383,6 +426,7 @@ StateGroup = Union[object, int]
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
+ room_version: Optional[RoomVersion] = None,
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
@@ -394,6 +438,31 @@ def _flatten_dict(
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
elif isinstance(value, Mapping):
+ # Do not pass `room_version` into the recursive call: the top-level handling below should only run once, for the event itself.
_flatten_dict(value, prefix=(prefix + [key]), result=result)
+ # `room_version` should only ever be set when looking at the top level of an event
+ if (
+ room_version is not None
+ and PushRuleRoomFlag.EXTENSIBLE_EVENTS in room_version.msc3931_push_features
+ and isinstance(d, EventBase)
+ ):
+ # Room supports extensible events: replace `content.body` with the plain text
+ # representation from `m.markup`, as per MSC1767.
+ markup = d.get("content").get("m.markup")
+ if room_version.identifier.startswith("org.matrix.msc1767."):
+ markup = d.get("content").get("org.matrix.msc1767.markup")
+ if markup is not None and isinstance(markup, list):
+ text = ""
+ for rep in markup:
+ if not isinstance(rep, dict):
+ # invalid markup - skip all processing
+ break
+ if rep.get("mimetype", "text/plain") == "text/plain":
+ rep_text = rep.get("body")
+ if rep_text is not None and isinstance(rep_text, str):
+ text = rep_text.lower()
+ break
+ result["content.body"] = text
+
return result
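As a reference for the flattening relied on throughout, a small self-contained approximation of the dotted-key flattening (string values only, lowercased); it is not `_flatten_dict` itself and deliberately omits the room-version/markup handling added above:

from typing import Any, Dict, Mapping


def flatten(d: Mapping[str, Any], prefix: str = "") -> Dict[str, str]:
    """Collapse nested mappings into dotted keys, keeping lowercased string values."""
    out: Dict[str, str] = {}
    for key, value in d.items():
        dotted = f"{prefix}{key}"
        if isinstance(value, str):
            out[dotted] = value.lower()
        elif isinstance(value, Mapping):
            out.update(flatten(value, prefix=f"{dotted}."))
    return out


# flatten({"content": {"body": "Hello", "m.relates_to": {"rel_type": "m.thread"}}})
# == {"content.body": "hello", "content.m.relates_to.rel_type": "m.thread"}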
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 7095ae83f9..622a1e35c5 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -44,6 +44,12 @@ def format_push_rules_for_user(
rulearray.append(template_rule)
+ pattern_type = template_rule.pop("pattern_type", None)
+ if pattern_type == "user_id":
+ template_rule["pattern"] = user.to_string()
+ elif pattern_type == "user_localpart":
+ template_rule["pattern"] = user.localpart
+
template_rule["enabled"] = enabled
if "conditions" not in template_rule:
@@ -93,10 +99,14 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
if len(rule.conditions) != 1:
return None
thecond = rule.conditions[0]
- if "pattern" not in thecond:
- return None
+
templaterule = {"actions": rule.actions}
- templaterule["pattern"] = thecond["pattern"]
+ if "pattern" in thecond:
+ templaterule["pattern"] = thecond["pattern"]
+ elif "pattern_type" in thecond:
+ templaterule["pattern_type"] = thecond["pattern_type"]
+ else:
+ return None
else:
# This should not be reached unless this function is not kept in sync
# with PRIORITY_CLASS_INVERSE_MAP.
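A short standalone sketch of the `pattern_type` expansion performed in the hunk above; `FakeUserID` and `expand_pattern_type` are illustrative stand-ins, not Synapse's `UserID` or helpers:

from typing import Any, Dict, NamedTuple


class FakeUserID(NamedTuple):
    localpart: str
    domain: str

    def to_string(self) -> str:
        return f"@{self.localpart}:{self.domain}"


def expand_pattern_type(template_rule: Dict[str, Any], user: FakeUserID) -> None:
    # Server-defined rules may carry a "pattern_type" placeholder rather than a
    # concrete "pattern"; expand it for the user the rules are being formatted for.
    pattern_type = template_rule.pop("pattern_type", None)
    if pattern_type == "user_id":
        template_rule["pattern"] = user.to_string()
    elif pattern_type == "user_localpart":
        template_rule["pattern"] = user.localpart


rule = {"actions": ["notify"], "pattern_type": "user_id"}
expand_pattern_type(rule, FakeUserID("alice", "example.org"))
assert rule["pattern"] == "@alice:example.org"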
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index c2575ba3d9..93b255ced5 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -37,8 +37,8 @@ from synapse.push.push_types import (
TemplateVars,
)
from synapse.storage.databases.main.event_push_actions import EmailPushAction
-from synapse.storage.state import StateFilter
from synapse.types import StateMap, UserID
+from synapse.types.state import StateFilter
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index edeba27a45..7ee07e4bee 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,7 +17,6 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
- room_notifs = []
-
- async def get_room_unread_count(room_id: str) -> None:
- room_notifs.append(
- await store.get_unread_event_push_actions_by_room_for_user(
- room_id,
- user_id,
- )
- )
-
- await concurrently_execute(get_room_unread_count, joins, 10)
-
- for notifs in room_notifs:
- # Combine the counts from all the threads.
- notify_count = notifs.main_timeline.notify_count + sum(
- n.notify_count for n in notifs.threads.values()
- )
+ room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
+ for room_id, notify_count in room_to_count.items():
+ # room_to_count may include rooms which the user has left;
+ # ignore those.
+ if room_id not in joins:
+ continue
if notify_count == 0:
continue
@@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
# return one badge count per conversation
badge += 1
else:
- # increment the badge count by the number of unread messages in the room
+ # Increase the badge count by the number of notifications in the room.
+ # NOTE: this includes threaded and unthreaded notifications.
badge += notify_count
+
return badge
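The badge arithmetic after this change can be summarised in a standalone sketch; `badge_from_counts` is a hypothetical helper, while the real code calls `store.get_unread_counts_by_room_for_user` and iterates as shown above:

from typing import AbstractSet, Mapping


def badge_from_counts(
    invite_count: int,
    joins: AbstractSet[str],
    room_to_count: Mapping[str, int],
    group_by_room: bool,
) -> int:
    badge = invite_count
    for room_id, notify_count in room_to_count.items():
        # Skip rooms the user has left and rooms with nothing to notify about.
        if room_id not in joins or notify_count == 0:
            continue
        # Either count each noisy room once, or add up all notifications.
        badge += 1 if group_by_room else notify_count
    return badge


assert badge_from_counts(1, {"!a", "!b"}, {"!a": 3, "!c": 5}, group_by_room=False) == 4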