diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 57c4d70466..a0c760239d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -116,6 +116,8 @@ class PusherConfig:
last_stream_ordering: int
last_success: Optional[int]
failing_since: Optional[int]
+ enabled: bool
+ device_id: Optional[str]
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
@@ -128,6 +130,8 @@ class PusherConfig:
"lang": self.lang,
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
+ "enabled": self.enabled,
+ "device_id": self.device_id,
}
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
deleted file mode 100644
index 6c0cc5a6ce..0000000000
--- a/synapse/push/baserules.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-from typing import Any, Dict, List
-
-from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-
-
-def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Combine the list of rules set by the user with the default push rules
-
- Args:
- rawrules: The rules the user has modified or set.
-
- Returns:
- A new list with the rules set by the user combined with the defaults.
- """
- ruleslist = []
-
- # Grab the base rules that the user has modified.
- # The modified base rules have a priority_class of -1.
- modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
-
- # Remove the modified base rules from the list, They'll be added back
- # in the default positions in the list.
- rawrules = [r for r in rawrules if r["priority_class"] >= 0]
-
- # shove the server default rules for each kind onto the end of each
- current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
-
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
-
- for r in rawrules:
- if r["priority_class"] < current_prio_class:
- while r["priority_class"] < current_prio_class:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- )
- )
-
- ruleslist.append(r)
-
- while current_prio_class > 0:
- ruleslist.extend(
- make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(
- make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- )
- )
-
- return ruleslist
-
-
-def make_base_append_rules(
- kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
-) -> List[Dict[str, Any]]:
- rules = []
-
- if kind == "override":
- rules = BASE_APPEND_OVERRIDE_RULES
- elif kind == "underride":
- rules = BASE_APPEND_UNDERRIDE_RULES
- elif kind == "content":
- rules = BASE_APPEND_CONTENT_RULES
-
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
-
- return rules
-
-
-def make_base_prepend_rules(
- kind: str,
- modified_base_rules: Dict[str, Dict[str, Any]],
-) -> List[Dict[str, Any]]:
- rules = []
-
- if kind == "override":
- rules = BASE_PREPEND_OVERRIDE_RULES
-
- # Copy the rules before modifying them
- rules = copy.deepcopy(rules)
- for r in rules:
- # Only modify the actions, keep the conditions the same.
- assert isinstance(r["rule_id"], str)
- modified = modified_base_rules.get(r["rule_id"])
- if modified:
- r["actions"] = modified["actions"]
-
- return rules
-
-
-# We have to annotate these types, otherwise mypy infers them as
-# `List[Dict[str, Sequence[Collection[str]]]]`.
-BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/content/.m.rule.contains_user_name",
- "conditions": [
- {
- "kind": "event_match",
- "key": "content.body",
- # Match the localpart of the requester's MXID.
- "pattern_type": "user_localpart",
- }
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
- ],
- }
-]
-
-
-BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.master",
- "enabled": False,
- "conditions": [],
- "actions": ["dont_notify"],
- }
-]
-
-
-BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/override/.m.rule.suppress_notices",
- "conditions": [
- {
- "kind": "event_match",
- "key": "content.msgtype",
- "pattern": "m.notice",
- "_cache_key": "_suppress_notices",
- }
- ],
- "actions": ["dont_notify"],
- },
- # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
- # otherwise invites will be matched by .m.rule.member_event
- {
- "rule_id": "global/override/.m.rule.invite_for_me",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.member",
- "_cache_key": "_member",
- },
- {
- "kind": "event_match",
- "key": "content.membership",
- "pattern": "invite",
- "_cache_key": "_invite_member",
- },
- # Match the requester's MXID.
- {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight", "value": False},
- ],
- },
- # Will we sometimes want to know about people joining and leaving?
- # Perhaps: if so, this could be expanded upon. Seems the most usual case
- # is that we don't though. We add this override rule so that even if
- # the room rule is set to notify, we don't get notifications about
- # join/leave/avatar/displayname events.
- # See also: https://matrix.org/jira/browse/SYN-607
- {
- "rule_id": "global/override/.m.rule.member_event",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.member",
- "_cache_key": "_member",
- }
- ],
- "actions": ["dont_notify"],
- },
- # This was changed from underride to override so it's closer in priority
- # to the content rules where the user name highlight rule lives. This
- # way a room rule is lower priority than both but a custom override rule
- # is higher priority than both.
- {
- "rule_id": "global/override/.m.rule.contains_display_name",
- "conditions": [{"kind": "contains_display_name"}],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
- ],
- },
- {
- "rule_id": "global/override/.m.rule.roomnotif",
- "conditions": [
- {
- "kind": "event_match",
- "key": "content.body",
- "pattern": "@room",
- "_cache_key": "_roomnotif_content",
- },
- {
- "kind": "sender_notification_permission",
- "key": "room",
- "_cache_key": "_roomnotif_pl",
- },
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.tombstone",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.tombstone",
- "_cache_key": "_tombstone",
- },
- {
- "kind": "event_match",
- "key": "state_key",
- "pattern": "",
- "_cache_key": "_tombstone_statekey",
- },
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": True}],
- },
- {
- "rule_id": "global/override/.m.rule.reaction",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.reaction",
- "_cache_key": "_reaction",
- }
- ],
- "actions": ["dont_notify"],
- },
- # XXX: This is an experimental rule that is only enabled if msc3786_enabled
- # is enabled, if it is not the rule gets filtered out in _load_rules() in
- # PushRulesWorkerStore
- {
- "rule_id": "global/override/.org.matrix.msc3786.rule.room.server_acl",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.server_acl",
- "_cache_key": "_room_server_acl",
- },
- {
- "kind": "event_match",
- "key": "state_key",
- "pattern": "",
- "_cache_key": "_room_server_acl_state_key",
- },
- ],
- "actions": [],
- },
-]
-
-
-BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
- {
- "rule_id": "global/underride/.m.rule.call",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.call.invite",
- "_cache_key": "_call",
- }
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "ring"},
- {"set_tweak": "highlight", "value": False},
- ],
- },
- # XXX: once m.direct is standardised everywhere, we should use it to detect
- # a DM from the user's perspective rather than this heuristic.
- {
- "rule_id": "global/underride/.m.rule.room_one_to_one",
- "conditions": [
- {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.message",
- "_cache_key": "_message",
- },
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight", "value": False},
- ],
- },
- # XXX: this is going to fire for events which aren't m.room.messages
- # but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
- "conditions": [
- {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.encrypted",
- "_cache_key": "_encrypted",
- },
- ],
- "actions": [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight", "value": False},
- ],
- },
- {
- "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
- "conditions": [
- {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.thread",
- # Match the requester's MXID.
- "sender_type": "user_id",
- }
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.m.rule.message",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.message",
- "_cache_key": "_message",
- }
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- # XXX: this is going to fire for events which aren't m.room.messages
- # but are encrypted (e.g. m.call.*)...
- {
- "rule_id": "global/underride/.m.rule.encrypted",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "m.room.encrypted",
- "_cache_key": "_encrypted",
- }
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
- {
- "rule_id": "global/underride/.im.vector.jitsi",
- "conditions": [
- {
- "kind": "event_match",
- "key": "type",
- "pattern": "im.vector.modular.widgets",
- "_cache_key": "_type_modular_widgets",
- },
- {
- "kind": "event_match",
- "key": "content.type",
- "pattern": "jitsi",
- "_cache_key": "_content_type_jitsi",
- },
- {
- "kind": "event_match",
- "key": "state_key",
- "pattern": "*",
- "_cache_key": "_is_state_event",
- },
- ],
- "actions": ["notify", {"set_tweak": "highlight", "value": False}],
- },
-]
-
-
-BASE_RULE_IDS = set()
-
-for r in BASE_APPEND_CONTENT_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["content"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
-
-for r in BASE_PREPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
-
-for r in BASE_APPEND_OVERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["override"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
-
-for r in BASE_APPEND_UNDERRIDE_RULES:
- r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
- r["default"] = True
- BASE_RULE_IDS.add(r["rule_id"])
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 713dcf6950..75b7e126ca 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,31 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
+from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
-from .push_rule_evaluator import PushRuleEvaluatorForEvent
-
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-
push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
@@ -99,6 +106,8 @@ class BulkPushRuleEvaluator:
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
+ self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled
+
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
@@ -106,13 +115,10 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- # Whether to support MSC3772 is supported.
- self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
-
async def _get_rules_for_event(
self,
event: EventBase,
- ) -> Dict[str, List[Dict[str, Any]]]:
+ ) -> Dict[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
@@ -160,23 +166,51 @@ class BulkPushRuleEvaluator:
return rules_by_user
async def _get_power_levels_and_sender_level(
- self, event: EventBase, context: EventContext
- ) -> Tuple[dict, int]:
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
+ ) -> Tuple[dict, Optional[int]]:
+ """
+ Given an event and an event context, get the power level event relevant to the event
+ and the power level of the sender of the event.
+ Args:
+ event: event to check
+ context: context of event to check
+ event_id_to_event: a mapping of event_id to event for a set of events being
+ batch persisted. This is needed as the sought-after power level event may
+ be in this batch rather than the DB
+ """
+ # There are no power levels and sender levels possible to get from outlier
+ if event.internal_metadata.is_outlier():
+ return {}, None
+
event_types = auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
pl_event_id = prev_state_ids.get(POWER_KEY)
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
if pl_event_id:
- # fastpath: if there's a power level event, that's all we need, and
- # not having a power level event is an extreme edge case
- auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
+ # Get the power level event from the batch, or fall back to the database.
+ pl_event = event_id_to_event.get(pl_event_id)
+ if pl_event:
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events_dict = await self.store.get_events(auth_events_ids)
+ # Some needed auth events might be in the batch, combine them with those
+ # fetched from the database.
+ for auth_event_id in auth_events_ids:
+ auth_event = event_id_to_event.get(auth_event_id)
+ if auth_event:
+ auth_events_dict[auth_event_id] = auth_event
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -185,76 +219,91 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _get_mutual_relations(
- self, event: EventBase, rules: Iterable[Dict[str, Any]]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
-
- Args:
- event_id: The event ID which is targeted by relations.
- rules: The push rules which will be processed for this event.
+ async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+ """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+ Mapping of relation type to flattened events.
"""
+ related_events: Dict[str, Dict[str, str]] = {}
+ if self._related_event_match_enabled:
+ related_event_id = event.content.get("m.relates_to", {}).get("event_id")
+ relation_type = event.content.get("m.relates_to", {}).get("rel_type")
+ if related_event_id is not None and relation_type is not None:
+ related_event = await self.store.get_event(
+ related_event_id, allow_none=True
+ )
+ if related_event is not None:
+ related_events[relation_type] = _flatten_dict(related_event)
- # If the experimental feature is not enabled, skip fetching relations.
- if not self._relations_match_enabled:
- return {}
+ reply_event_id = (
+ event.content.get("m.relates_to", {})
+ .get("m.in_reply_to", {})
+ .get("event_id")
+ )
- # If the event does not have a relation, then cannot have any mutual
- # relations.
- relation = relation_from_event(event)
- if not relation:
- return {}
-
- # Pre-filter to figure out which relation types are interesting.
- rel_types = set()
- for rule in rules:
- # Skip disabled rules.
- if "enabled" in rule and not rule["enabled"]:
- continue
+ # convert replies to pseudo relations
+ if reply_event_id is not None:
+ related_event = await self.store.get_event(
+ reply_event_id, allow_none=True
+ )
- for condition in rule["conditions"]:
- if condition["kind"] != "org.matrix.msc3772.relation_match":
- continue
+ if related_event is not None:
+ related_events["m.in_reply_to"] = _flatten_dict(related_event)
- # rel_type is required.
- rel_type = condition.get("rel_type")
- if rel_type:
- rel_types.add(rel_type)
+ # indicate that this is from a fallback relation.
+ if relation_type == "m.thread" and event.content.get(
+ "m.relates_to", {}
+ ).get("is_falling_back", False):
+ related_events["m.in_reply_to"][
+ "im.vector.is_falling_back"
+ ] = ""
- # If no valid rules were found, no mutual relations.
- if not rel_types:
- return {}
+ return related_events
- # If any valid rules were found, fetch the mutual relations.
- return await self.store.get_mutual_event_relations(
- relation.parent_id, rel_types
- )
+ async def action_for_events_by_user(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
+ ) -> None:
+ """Given a list of events and their associated contexts, evaluate the push rules
+ for each event, check if the message should increment the unread count, and
+ insert the results into the event_push_actions_staging table.
+ """
+ # For batched events the power level events may not have been persisted yet,
+ # so we pass in the batched events. Thus if the event cannot be found in the
+ # database we can check in the batch.
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ for event, context in events_and_context:
+ await self._action_for_event_by_user(event, context, event_id_to_event)
@measure_func("action_for_event_by_user")
- async def action_for_event_by_user(
- self, event: EventBase, context: EventContext
+ async def _action_for_event_by_user(
+ self,
+ event: EventBase,
+ context: EventContext,
+ event_id_to_event: Mapping[str, EventBase],
) -> None:
- """Given an event and context, evaluate the push rules, check if the message
- should increment the unread count, and insert the results into the
- event_push_actions_staging table.
- """
- if event.internal_metadata.is_outlier():
- # This can happen due to out of band memberships
+
+ if (
+ not event.internal_metadata.is_notifiable()
+ or event.internal_metadata.is_historical()
+ ):
+ # Push rules for events that aren't notifiable can't be processed by this and
+ # we want to skip push notification actions for historical messages
+ # because we don't want to notify people about old history back in time.
+ # The historical messages also do not have the proper `context.current_state_ids`
+ # and `state_groups` because they have `prev_events` that aren't persisted yet
+ # (historical messages persisted in reverse-chronological order).
return
- count_as_unread = _should_count_as_unread(event, context)
+ # Disable counting as unread unless the experimental configuration is
+ # enabled, as it can cause additional (unwanted) rows to be added to the
+ # event_push_actions table.
+ count_as_unread = False
+ if self.hs.config.experimental.msc2654_enabled:
+ count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event)
- actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+ actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
@@ -263,19 +312,39 @@ class BulkPushRuleEvaluator:
(
power_levels,
sender_power_level,
- ) = await self._get_power_levels_and_sender_level(event, context)
-
- relations = await self._get_mutual_relations(
- event, itertools.chain(*rules_by_user.values())
+ ) = await self._get_power_levels_and_sender_level(
+ event, context, event_id_to_event
)
- evaluator = PushRuleEvaluatorForEvent(
- event,
+ # Find the event's thread ID.
+ relation = relation_from_event(event)
+ # If the event does not have a relation, then it cannot have a thread ID.
+ thread_id = MAIN_TIMELINE
+ if relation:
+ # Recursively attempt to find the thread this event relates to.
+ if relation.rel_type == RelationTypes.THREAD:
+ thread_id = relation.parent_id
+ else:
+ # Since the event has not yet been persisted we check whether
+ # the parent is part of a thread.
+ thread_id = await self.store.get_thread_id(relation.parent_id)
+
+ related_events = await self._related_events(event)
+
+ # It's possible that old room versions have non-integer power levels (floats or
+ # strings). Workaround this by explicitly converting to int.
+ notification_levels = power_levels.get("notifications", {})
+ if not event.room_version.msc3667_int_only_power_levels:
+ for user_id, level in notification_levels.items():
+ notification_levels[user_id] = int(level)
+
+ evaluator = PushRuleEvaluator(
+ _flatten_dict(event),
room_member_count,
sender_power_level,
- power_levels,
- relations,
- self._relations_match_enabled,
+ notification_levels,
+ related_events,
+ self._related_event_match_enabled,
)
users = rules_by_user.keys()
@@ -283,20 +352,10 @@ class BulkPushRuleEvaluator:
event.room_id, users
)
- # This is a check for the case where user joins a room without being
- # allowed to see history, and then the server receives a delayed event
- # from before the user joined, which they should not be pushed for
- uids_with_visibility = await filter_event_for_clients_with_state(
- self.store, users, event, context
- )
-
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
- if uid not in uids_with_visibility:
- continue
-
display_name = None
profile = profiles.get(uid)
if profile:
@@ -317,19 +376,30 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
- for rule in rules:
- if "enabled" in rule and not rule["enabled"]:
- continue
+ actions = evaluator.run(rules, uid, display_name)
+ if "notify" in actions:
+ # Push rules say we should notify the user of this event
+ actions_by_user[uid] = actions
- matches = evaluator.check_conditions(
- rule["conditions"], uid, display_name
- )
- if matches:
- actions = [x for x in rule["actions"] if x != "dont_notify"]
- if actions and "notify" in actions:
- # Push rules say we should notify the user of this event
- actions_by_user[uid] = actions
- break
+ # If there aren't any actions then we can skip the rest of the
+ # processing.
+ if not actions_by_user:
+ return
+
+ # This is a check for the case where user joins a room without being
+ # allowed to see history, and then the server receives a delayed event
+ # from before the user joined, which they should not be pushed for
+ #
+ # We do this *after* calculating the push actions as a) its unlikely
+ # that we'll filter anyone out and b) for large rooms its likely that
+ # most users will have push disabled and so the set of users to check is
+ # much smaller.
+ uids_with_visibility = await filter_event_for_clients_with_state(
+ self.store, actions_by_user.keys(), event, context
+ )
+
+ for user_id in set(actions_by_user).difference(uids_with_visibility):
+ actions_by_user.pop(user_id, None)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
@@ -338,6 +408,7 @@ class BulkPushRuleEvaluator:
event.event_id,
actions_by_user,
count_as_unread,
+ thread_id,
)
@@ -345,3 +416,21 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
+
+
+def _flatten_dict(
+ d: Union[EventBase, Mapping[str, Any]],
+ prefix: Optional[List[str]] = None,
+ result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+ if prefix is None:
+ prefix = []
+ if result is None:
+ result = {}
+ for key, value in d.items():
+ if isinstance(value, str):
+ result[".".join(prefix + [key])] = value.lower()
+ elif isinstance(value, Mapping):
+ _flatten_dict(value, prefix=(prefix + [key]), result=result)
+
+ return result
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 5117ef6854..622a1e35c5 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -16,18 +16,16 @@ import copy
from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.synapse_rust.push import FilteredPushRules, PushRule
from synapse.types import UserID
def format_push_rules_for_user(
- user: UserID, ruleslist: List
+ user: UserID, ruleslist: FilteredPushRules
) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
- # We're going to be mutating this a lot, so do a deep copy
- ruleslist = copy.deepcopy(ruleslist)
-
rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
"global": {},
"device": {},
@@ -35,11 +33,36 @@ def format_push_rules_for_user(
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
- for r in ruleslist:
- template_name = _priority_class_to_template_name(r["priority_class"])
+ for r, enabled in ruleslist.rules():
+ template_name = _priority_class_to_template_name(r.priority_class)
+
+ rulearray = rules["global"][template_name]
+
+ template_rule = _rule_to_template(r)
+ if not template_rule:
+ continue
+
+ rulearray.append(template_rule)
+
+ pattern_type = template_rule.pop("pattern_type", None)
+ if pattern_type == "user_id":
+ template_rule["pattern"] = user.to_string()
+ elif pattern_type == "user_localpart":
+ template_rule["pattern"] = user.localpart
+
+ template_rule["enabled"] = enabled
+
+ if "conditions" not in template_rule:
+ # Not all formatted rules have explicit conditions, e.g. "room"
+ # rules omit them as they can be derived from the kind and rule ID.
+ #
+ # If the formatted rule has no conditions then we can skip the
+ # formatting of conditions.
+ continue
# Remove internal stuff.
- for c in r["conditions"]:
+ template_rule["conditions"] = copy.deepcopy(template_rule["conditions"])
+ for c in template_rule["conditions"]:
c.pop("_cache_key", None)
pattern_type = c.pop("pattern_type", None)
@@ -52,16 +75,6 @@ def format_push_rules_for_user(
if sender_type == "user_id":
c["sender"] = user.to_string()
- rulearray = rules["global"][template_name]
-
- template_rule = _rule_to_template(r)
- if template_rule:
- if "enabled" in r:
- template_rule["enabled"] = r["enabled"]
- else:
- template_rule["enabled"] = True
- rulearray.append(template_rule)
-
return rules
@@ -71,34 +84,36 @@ def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
return d
-def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- unscoped_rule_id = None
- if "rule_id" in rule:
- unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
+def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
+ templaterule: Dict[str, Any]
+
+ unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
- template_name = _priority_class_to_template_name(rule["priority_class"])
+ template_name = _priority_class_to_template_name(rule.priority_class)
if template_name in ["override", "underride"]:
- templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+ templaterule = {"conditions": rule.conditions, "actions": rule.actions}
elif template_name in ["sender", "room"]:
- templaterule = {"actions": rule["actions"]}
- unscoped_rule_id = rule["conditions"][0]["pattern"]
+ templaterule = {"actions": rule.actions}
+ unscoped_rule_id = rule.conditions[0]["pattern"]
elif template_name == "content":
- if len(rule["conditions"]) != 1:
+ if len(rule.conditions) != 1:
return None
- thecond = rule["conditions"][0]
- if "pattern" not in thecond:
+ thecond = rule.conditions[0]
+
+ templaterule = {"actions": rule.actions}
+ if "pattern" in thecond:
+ templaterule["pattern"] = thecond["pattern"]
+ elif "pattern_type" in thecond:
+ templaterule["pattern_type"] = thecond["pattern_type"]
+ else:
return None
- templaterule = {"actions": rule["actions"]}
- templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
# with PRIORITY_CLASS_INVERSE_MAP.
raise ValueError("Unexpected template_name: %s" % (template_name,))
- if unscoped_rule_id:
- templaterule["rule_id"] = unscoped_rule_id
- if "default" in rule:
- templaterule["default"] = rule["default"]
+ templaterule["rule_id"] = unscoped_rule_id
+ templaterule["default"] = rule.default
return templaterule
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e96fb45e9f..b048b03a74 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@@ -28,7 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.storage.databases.main.event_push_actions import HttpPushAction
-from . import push_rule_evaluator, push_tools
+from . import push_tools
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -56,6 +56,39 @@ http_badges_failed_counter = Counter(
)
+def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
+ """
+ Converts a list of actions into a `tweaks` dict (which can then be passed to
+ the push gateway).
+
+ This function ignores all actions other than `set_tweak` actions, and treats
+ absent `value`s as `True`, which agrees with the only spec-defined treatment
+ of absent `value`s (namely, for `highlight` tweaks).
+
+ Args:
+ actions: list of actions
+ e.g. [
+ {"set_tweak": "a", "value": "AAA"},
+ {"set_tweak": "b", "value": "BBB"},
+ {"set_tweak": "highlight"},
+ "notify"
+ ]
+
+ Returns:
+ dictionary of tweaks for those actions
+ e.g. {"a": "AAA", "b": "BBB", "highlight": True}
+ """
+ tweaks = {}
+ for a in actions:
+ if not isinstance(a, dict):
+ continue
+ if "set_tweak" in a:
+ # value is allowed to be absent in which case the value assumed
+ # should be True.
+ tweaks[a["set_tweak"]] = a.get("value", True)
+ return tweaks
+
+
class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
@@ -281,7 +314,7 @@ class HttpPusher(Pusher):
if "notify" not in push_action.actions:
return True
- tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
+ tweaks = tweaks_for_actions(push_action.actions)
badge = await push_tools.get_badge_count(
self.hs.get_datastores().main,
self.user_id,
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
deleted file mode 100644
index 2e8a017add..0000000000
--- a/synapse/push/push_rule_evaluator.py
+++ /dev/null
@@ -1,350 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
-
-from matrix_common.regex import glob_to_regex, to_word_pattern
-
-from synapse.events import EventBase
-from synapse.types import UserID
-from synapse.util.caches.lrucache import LruCache
-
-logger = logging.getLogger(__name__)
-
-
-GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
-IS_GLOB = re.compile(r"[\?\*\[\]]")
-INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
-
-
-def _room_member_count(
- ev: EventBase, condition: Dict[str, Any], room_member_count: int
-) -> bool:
- return _test_ineq_condition(condition, room_member_count)
-
-
-def _sender_notification_permission(
- ev: EventBase,
- condition: Dict[str, Any],
- sender_power_level: int,
- power_levels: Dict[str, Union[int, Dict[str, int]]],
-) -> bool:
- notif_level_key = condition.get("key")
- if notif_level_key is None:
- return False
-
- notif_levels = power_levels.get("notifications", {})
- assert isinstance(notif_levels, dict)
- room_notif_level = notif_levels.get(notif_level_key, 50)
-
- return sender_power_level >= room_notif_level
-
-
-def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
- if "is" not in condition:
- return False
- m = INEQUALITY_EXPR.match(condition["is"])
- if not m:
- return False
- ineq = m.group(1)
- rhs = m.group(2)
- if not rhs.isdigit():
- return False
- rhs_int = int(rhs)
-
- if ineq == "" or ineq == "==":
- return number == rhs_int
- elif ineq == "<":
- return number < rhs_int
- elif ineq == ">":
- return number > rhs_int
- elif ineq == ">=":
- return number >= rhs_int
- elif ineq == "<=":
- return number <= rhs_int
- else:
- return False
-
-
-def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
- """
- Converts a list of actions into a `tweaks` dict (which can then be passed to
- the push gateway).
-
- This function ignores all actions other than `set_tweak` actions, and treats
- absent `value`s as `True`, which agrees with the only spec-defined treatment
- of absent `value`s (namely, for `highlight` tweaks).
-
- Args:
- actions: list of actions
- e.g. [
- {"set_tweak": "a", "value": "AAA"},
- {"set_tweak": "b", "value": "BBB"},
- {"set_tweak": "highlight"},
- "notify"
- ]
-
- Returns:
- dictionary of tweaks for those actions
- e.g. {"a": "AAA", "b": "BBB", "highlight": True}
- """
- tweaks = {}
- for a in actions:
- if not isinstance(a, dict):
- continue
- if "set_tweak" in a:
- # value is allowed to be absent in which case the value assumed
- # should be True.
- tweaks[a["set_tweak"]] = a.get("value", True)
- return tweaks
-
-
-class PushRuleEvaluatorForEvent:
- def __init__(
- self,
- event: EventBase,
- room_member_count: int,
- sender_power_level: int,
- power_levels: Dict[str, Union[int, Dict[str, int]]],
- relations: Dict[str, Set[Tuple[str, str]]],
- relations_match_enabled: bool,
- ):
- self._event = event
- self._room_member_count = room_member_count
- self._sender_power_level = sender_power_level
- self._power_levels = power_levels
- self._relations = relations
- self._relations_match_enabled = relations_match_enabled
-
- # Maps strings of e.g. 'content.body' -> event["content"]["body"]
- self._value_cache = _flatten_dict(event)
-
- # Maps cache keys to final values.
- self._condition_cache: Dict[str, bool] = {}
-
- def check_conditions(
- self, conditions: List[dict], uid: str, display_name: Optional[str]
- ) -> bool:
- """
- Returns true if a user's conditions/user ID/display name match the event.
-
- Args:
- conditions: The user's conditions to match.
- uid: The user's MXID.
- display_name: The display name.
-
- Returns:
- True if all conditions match the event, False otherwise.
- """
- for cond in conditions:
- _cache_key = cond.get("_cache_key", None)
- if _cache_key:
- res = self._condition_cache.get(_cache_key, None)
- if res is False:
- return False
- elif res is True:
- continue
-
- res = self.matches(cond, uid, display_name)
- if _cache_key:
- self._condition_cache[_cache_key] = bool(res)
-
- if not res:
- return False
-
- return True
-
- def matches(
- self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
- ) -> bool:
- """
- Returns true if a user's condition/user ID/display name match the event.
-
- Args:
- condition: The user's condition to match.
- uid: The user's MXID.
- display_name: The display name, or None if there is not one.
-
- Returns:
- True if the condition matches the event, False otherwise.
- """
- if condition["kind"] == "event_match":
- return self._event_match(condition, user_id)
- elif condition["kind"] == "contains_display_name":
- return self._contains_display_name(display_name)
- elif condition["kind"] == "room_member_count":
- return _room_member_count(self._event, condition, self._room_member_count)
- elif condition["kind"] == "sender_notification_permission":
- return _sender_notification_permission(
- self._event, condition, self._sender_power_level, self._power_levels
- )
- elif (
- condition["kind"] == "org.matrix.msc3772.relation_match"
- and self._relations_match_enabled
- ):
- return self._relation_match(condition, user_id)
- else:
- # XXX This looks incorrect -- we have reached an unknown condition
- # kind and are unconditionally returning that it matches. Note
- # that it seems possible to provide a condition to the /pushrules
- # endpoint with an unknown kind, see _rule_tuple_from_request_object.
- return True
-
- def _event_match(self, condition: dict, user_id: str) -> bool:
- """
- Check an "event_match" push rule condition.
-
- Args:
- condition: The "event_match" push rule condition to match.
- user_id: The user's MXID.
-
- Returns:
- True if the condition matches the event, False otherwise.
- """
- pattern = condition.get("pattern", None)
-
- if not pattern:
- pattern_type = condition.get("pattern_type", None)
- if pattern_type == "user_id":
- pattern = user_id
- elif pattern_type == "user_localpart":
- pattern = UserID.from_string(user_id).localpart
-
- if not pattern:
- logger.warning("event_match condition with no pattern")
- return False
-
- # XXX: optimisation: cache our pattern regexps
- if condition["key"] == "content.body":
- body = self._event.content.get("body", None)
- if not body or not isinstance(body, str):
- return False
-
- return _glob_matches(pattern, body, word_boundary=True)
- else:
- haystack = self._value_cache.get(condition["key"], None)
- if haystack is None:
- return False
-
- return _glob_matches(pattern, haystack)
-
- def _contains_display_name(self, display_name: Optional[str]) -> bool:
- """
- Check an "event_match" push rule condition.
-
- Args:
- display_name: The display name, or None if there is not one.
-
- Returns:
- True if the display name is found in the event body, False otherwise.
- """
- if not display_name:
- return False
-
- body = self._event.content.get("body", None)
- if not body or not isinstance(body, str):
- return False
-
- # Similar to _glob_matches, but do not treat display_name as a glob.
- r = regex_cache.get((display_name, False, True), None)
- if not r:
- r1 = re.escape(display_name)
- r1 = to_word_pattern(r1)
- r = re.compile(r1, flags=re.IGNORECASE)
- regex_cache[(display_name, False, True)] = r
-
- return bool(r.search(body))
-
- def _relation_match(self, condition: dict, user_id: str) -> bool:
- """
- Check an "relation_match" push rule condition.
-
- Args:
- condition: The "event_match" push rule condition to match.
- user_id: The user's MXID.
-
- Returns:
- True if the condition matches the event, False otherwise.
- """
- rel_type = condition.get("rel_type")
- if not rel_type:
- logger.warning("relation_match condition missing rel_type")
- return False
-
- sender_pattern = condition.get("sender")
- if sender_pattern is None:
- sender_type = condition.get("sender_type")
- if sender_type == "user_id":
- sender_pattern = user_id
- type_pattern = condition.get("type")
-
- # If any other relations matches, return True.
- for sender, event_type in self._relations.get(rel_type, ()):
- if sender_pattern and not _glob_matches(sender_pattern, sender):
- continue
- if type_pattern and not _glob_matches(type_pattern, event_type):
- continue
- # All values must have matched.
- return True
-
- # No relations matched.
- return False
-
-
-# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
- 50000, "regex_push_cache"
-)
-
-
-def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
- """Tests if value matches glob.
-
- Args:
- glob
- value: String to test against glob.
- word_boundary: Whether to match against word boundaries or entire
- string. Defaults to False.
- """
-
- try:
- r = regex_cache.get((glob, True, word_boundary), None)
- if not r:
- r = glob_to_regex(glob, word_boundary=word_boundary)
- regex_cache[(glob, True, word_boundary)] = r
- return bool(r.search(value))
- except re.error:
- logger.warning("Failed to parse glob to regex: %r", glob)
- return False
-
-
-def _flatten_dict(
- d: Union[EventBase, Mapping[str, Any]],
- prefix: Optional[List[str]] = None,
- result: Optional[Dict[str, str]] = None,
-) -> Dict[str, str]:
- if prefix is None:
- prefix = []
- if result is None:
- result = {}
- for key, value in d.items():
- if isinstance(value, str):
- result[".".join(prefix + [key])] = value.lower()
- elif isinstance(value, Mapping):
- _flatten_dict(value, prefix=(prefix + [key]), result=result)
-
- return result
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6661887d9f..edeba27a45 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,6 +17,7 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
+from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -25,14 +26,25 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
- for room_id in joins:
- notifs = await (
- store.get_unread_event_push_actions_by_room_for_user(
+ room_notifs = []
+
+ async def get_room_unread_count(room_id: str) -> None:
+ room_notifs.append(
+ await store.get_unread_event_push_actions_by_room_for_user(
room_id,
user_id,
)
)
- if notifs.notify_count == 0:
+
+ await concurrently_execute(get_room_unread_count, joins, 10)
+
+ for notifs in room_notifs:
+ # Combine the counts from all the threads.
+ notify_count = notifs.main_timeline.notify_count + sum(
+ n.notify_count for n in notifs.threads.values()
+ )
+
+ if notify_count == 0:
continue
if group_by_room:
@@ -40,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
- badge += notifs.notify_count
+ badge += notify_count
return badge
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 1e0ef44fc7..e2648cbc93 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -94,7 +94,7 @@ class PusherPool:
return
run_as_background_process("start_pushers", self._start_pushers)
- async def add_pusher(
+ async def add_or_update_pusher(
self,
user_id: str,
access_token: Optional[int],
@@ -106,6 +106,8 @@ class PusherPool:
lang: Optional[str],
data: JsonDict,
profile_tag: str = "",
+ enabled: bool = True,
+ device_id: Optional[str] = None,
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@@ -147,9 +149,22 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
last_success=None,
failing_since=None,
+ enabled=enabled,
+ device_id=device_id,
)
)
+ # Before we actually persist the pusher, we check if the user already has one
+ # this app ID and pushkey. If so, we want to keep the access token and device ID
+ # in place, since this could be one device modifying (e.g. enabling/disabling)
+ # another device's pusher.
+ existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+ user_id, app_id, pushkey
+ )
+ if existing_config:
+ access_token = existing_config.access_token
+ device_id = existing_config.device_id
+
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
@@ -163,8 +178,10 @@ class PusherPool:
data=data,
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
+ enabled=enabled,
+ device_id=device_id,
)
- pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
+ pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
return pusher
@@ -276,10 +293,25 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
- async def start_pusher_by_id(
+ async def _get_pusher_config_for_user_by_app_id_and_pushkey(
+ self, user_id: str, app_id: str, pushkey: str
+ ) -> Optional[PusherConfig]:
+ resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+
+ pusher_config = None
+ for r in resultlist:
+ if r.user_name == user_id:
+ pusher_config = r
+
+ return pusher_config
+
+ async def process_pusher_change_by_id(
self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]:
- """Look up the details for the given pusher, and start it
+ """Look up the details for the given pusher, and either start it if its
+ "enabled" flag is True, or try to stop it otherwise.
+
+ If the pusher is new and its "enabled" flag is False, the stop is a noop.
Returns:
The pusher started, if any
@@ -290,12 +322,13 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return None
- resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+ pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+ user_id, app_id, pushkey
+ )
- pusher_config = None
- for r in resultlist:
- if r.user_name == user_id:
- pusher_config = r
+ if pusher_config and not pusher_config.enabled:
+ self.maybe_stop_pusher(app_id, pushkey, user_id)
+ return None
pusher = None
if pusher_config:
@@ -305,7 +338,7 @@ class PusherPool:
async def _start_pushers(self) -> None:
"""Start all the pushers"""
- pushers = await self.store.get_all_pushers()
+ pushers = await self.store.get_enabled_pushers()
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
@@ -363,6 +396,8 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
+ logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
+
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
@@ -382,16 +417,7 @@ class PusherPool:
return pusher
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
- appid_pushkey = "%s:%s" % (app_id, pushkey)
-
- byuser = self.pushers.get(user_id, {})
-
- if appid_pushkey in byuser:
- logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
- pusher = byuser.pop(appid_pushkey)
- pusher.on_stop()
-
- synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
+ self.maybe_stop_pusher(app_id, pushkey, user_id)
# We can only delete pushers on master.
if self._remove_pusher_client:
@@ -402,3 +428,22 @@ class PusherPool:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)
+
+ def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
+ """Stops a pusher with the given app ID and push key if one is running.
+
+ Args:
+ app_id: the pusher's app ID.
+ pushkey: the pusher's push key.
+ user_id: the user the pusher belongs to. Only used for logging.
+ """
+ appid_pushkey = "%s:%s" % (app_id, pushkey)
+
+ byuser = self.pushers.get(user_id, {})
+
+ if appid_pushkey in byuser:
+ logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
+ pusher = byuser.pop(appid_pushkey)
+ pusher.on_stop()
+
+ synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|