diff options
author | Erik Johnston <erik@matrix.org> | 2022-09-07 16:54:19 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-09-09 15:10:52 +0100 |
commit | 310df251730c895ceef6f030648232d517ab5a4a (patch) | |
tree | 1b170e11c44121bf7a37e566091a07b327d8ef17 | |
parent | Experimental rules (diff) | |
download | synapse-310df251730c895ceef6f030648232d517ab5a4a.tar.xz |
Fixup
-rw-r--r-- | rust/src/push/evaluator.rs | 244 | ||||
-rw-r--r-- | rust/src/push/mod.rs | 239 | ||||
-rw-r--r-- | stubs/synapse/synapse_rust/__init__.pyi (renamed from stubs/synapse/synapse_rust.pyi) | 0 | ||||
-rw-r--r-- | stubs/synapse/synapse_rust/push.pyi | 47 | ||||
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 2 | ||||
-rw-r--r-- | synapse/push/clientformat.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_deactivate_account.py | 27 |
8 files changed, 310 insertions, 256 deletions
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs new file mode 100644 index 0000000000..95d847cac0 --- /dev/null +++ b/rust/src/push/evaluator.rs @@ -0,0 +1,244 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use anyhow::{Context, Error}; +use log::warn; +use pyo3::prelude::*; + +use super::{ + utils::{get_localpart_from_id, glob_to_regex, GlobMatchType}, + Action, Condition, EventMatchCondition, FilteredPushRules, INEQUALITY_EXPR, +}; + +#[pyclass] +pub struct PushRuleEvaluator { + flattened_keys: BTreeMap<String, String>, + body: String, + room_member_count: u64, + power_levels: BTreeMap<String, BTreeMap<String, u64>>, + relations: BTreeMap<String, BTreeSet<(String, String)>>, + relation_match_enabled: bool, + sender_power_level: u64, +} + +#[pymethods] +impl PushRuleEvaluator { + #[new] + fn py_new( + flattened_keys: BTreeMap<String, String>, + room_member_count: u64, + sender_power_level: u64, + power_levels: BTreeMap<String, BTreeMap<String, u64>>, + relations: BTreeMap<String, BTreeSet<(String, String)>>, + relation_match_enabled: bool, + ) -> Result<Self, Error> { + let body = flattened_keys + .get("content.body") + .cloned() + .unwrap_or_default(); + + Ok(PushRuleEvaluator { + flattened_keys, + body, + room_member_count, + power_levels, + relations, + relation_match_enabled, + sender_power_level, + }) + } + + fn run( + &self, + push_rules: &FilteredPushRules, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Vec<Action> { + let mut actions = Vec::new(); + 'outer: for (push_rule, enabled) in push_rules.iter() { + if !enabled { + continue; + } + + for condition in push_rule.conditions.iter() { + match self.match_condition(condition, user_id, display_name) { + Ok(true) => {} + Ok(false) => continue 'outer, + Err(err) => { + warn!("Condition match failed {err}"); + continue 'outer; + } + } + } + + actions.extend( + push_rule + .actions + .iter() + // .filter(|a| **a != Action::DontNotify) + .cloned(), + ); + + return actions; + } + + actions + } +} + +impl PushRuleEvaluator { + pub fn match_condition( + &self, + condition: &Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Result<bool, Error> { + let result = match condition { + Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?, + Condition::ContainsDisplayName => { + if let Some(dn) = display_name { + let matcher = glob_to_regex(dn, GlobMatchType::Word)?; + matcher.is_match(&self.body) + } else { + false + } + } + Condition::RoomMemberCount { is } => { + if let Some(is) = is { + self.match_member_count(is)? + } else { + false + } + } + Condition::SenderNotificationPermission { key } => { + let required_level = self + .power_levels + .get("notifications") + .and_then(|m| m.get(key.as_ref())) + .copied() + .unwrap_or(50); + + self.sender_power_level >= required_level + } + Condition::RelationMatch { + rel_type, + sender, + sender_type, + } => { + if !self.relation_match_enabled { + return Ok(false); + } + + let sender_pattern = if let Some(sender) = sender { + sender + } else if let Some(sender_type) = sender_type { + if sender_type == "user_id" { + if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + } + } else { + warn!("Unrecognized sender_type: {sender_type}"); + return Ok(false); + } + } else { + warn!("relation_match condition missing sender or sender_type"); + return Ok(false); + }; + + let relations = if let Some(relations) = self.relations.get(&**rel_type) { + relations + } else { + return Ok(false); + }; + + let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?; + let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?; + + for (relation_sender, event_type) in relations { + if sender_compiled_pattern.is_match(&relation_sender) + && rel_type_compiled_pattern.is_match(event_type) + { + return Ok(true); + } + } + + false + } + }; + + Ok(result) + } + + fn match_event_match( + &self, + event_match: &EventMatchCondition, + user_id: Option<&str>, + ) -> Result<bool, Error> { + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + if event_match.key == "content.body" { + let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?; + Ok(compiled_pattern.is_match(&self.body)) + } else if let Some(value) = self.flattened_keys.get(&*event_match.key) { + let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?; + Ok(compiled_pattern.is_match(value)) + } else { + Ok(false) + } + } + + fn match_member_count(&self, is: &str) -> Result<bool, Error> { + let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?; + let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("=="); + let rhs: u64 = captures + .get(2) + .context("missing number")? + .as_str() + .parse()?; + + let matches = match ineq { + "" | "==" => self.room_member_count == rhs, + "<" => self.room_member_count < rhs, + ">" => self.room_member_count > rhs, + ">=" => self.room_member_count >= rhs, + "<=" => self.room_member_count <= rhs, + _ => false, + }; + + Ok(matches) + } +} + +#[test] +fn push_rule_evaluator() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); + + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + assert_eq!(result.len(), 3); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 6a099a15fa..33b2bc5d47 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -5,7 +5,7 @@ //! allocation atm). use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{BTreeMap, HashMap}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -17,9 +17,10 @@ use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; -use self::utils::{glob_to_regex, GlobMatchType}; +use self::evaluator::PushRuleEvaluator; mod base_rules; +mod evaluator; mod utils; lazy_static! { @@ -349,222 +350,6 @@ impl FilteredPushRules { } } -#[pyclass] -pub struct PushRuleEvaluator { - flattened_keys: BTreeMap<String, String>, - body: String, - room_member_count: u64, - power_levels: BTreeMap<String, BTreeMap<String, u64>>, - relations: BTreeMap<String, BTreeSet<(String, String)>>, - relation_match_enabled: bool, - sender_power_level: u64, -} - -#[pymethods] -impl PushRuleEvaluator { - #[new] - fn py_new( - flattened_keys: BTreeMap<String, String>, - room_member_count: u64, - sender_power_level: u64, - power_levels: BTreeMap<String, BTreeMap<String, u64>>, - relations: BTreeMap<String, BTreeSet<(String, String)>>, - relation_match_enabled: bool, - ) -> Result<Self, Error> { - let body = flattened_keys - .get("content.body") - .cloned() - .unwrap_or_default(); - - Ok(PushRuleEvaluator { - flattened_keys, - body, - room_member_count, - power_levels, - relations, - relation_match_enabled, - sender_power_level, - }) - } - - fn run( - &self, - push_rules: &FilteredPushRules, - user_id: Option<&str>, - display_name: Option<&str>, - ) -> Vec<Action> { - let mut actions = Vec::new(); - 'outer: for (push_rule, enabled) in push_rules.iter() { - if !enabled { - continue; - } - - for condition in push_rule.conditions.iter() { - match self.match_condition(condition, user_id, display_name) { - Ok(true) => {} - Ok(false) => continue 'outer, - Err(err) => { - warn!("Condition match failed {err}"); - continue 'outer; - } - } - } - - actions.extend( - push_rule - .actions - .iter() - // .filter(|a| **a != Action::DontNotify) - .cloned(), - ); - - return actions; - } - - actions - } -} - -impl PushRuleEvaluator { - pub fn match_condition( - &self, - condition: &Condition, - user_id: Option<&str>, - display_name: Option<&str>, - ) -> Result<bool, Error> { - let result = match condition { - Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?, - Condition::ContainsDisplayName => { - if let Some(dn) = display_name { - let matcher = glob_to_regex(dn, GlobMatchType::Word)?; - matcher.is_match(&self.body) - } else { - false - } - } - Condition::RoomMemberCount { is } => { - if let Some(is) = is { - self.match_member_count(is)? - } else { - false - } - } - Condition::SenderNotificationPermission { key } => { - let required_level = self - .power_levels - .get("notifications") - .and_then(|m| m.get(key.as_ref())) - .copied() - .unwrap_or(50); - - self.sender_power_level >= required_level - } - Condition::RelationMatch { - rel_type, - sender, - sender_type, - } => { - if !self.relation_match_enabled { - return Ok(false); - } - - let sender_pattern = if let Some(sender) = sender { - sender - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - warn!("relation_match condition missing sender or sender_type"); - return Ok(false); - }; - - let relations = if let Some(relations) = self.relations.get(&**rel_type) { - relations - } else { - return Ok(false); - }; - - let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?; - let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?; - - for (relation_sender, event_type) in relations { - if sender_compiled_pattern.is_match(&relation_sender) - && rel_type_compiled_pattern.is_match(event_type) - { - return Ok(true); - } - } - - false - } - }; - - Ok(result) - } - - fn match_event_match( - &self, - event_match: &EventMatchCondition, - user_id: Option<&str>, - ) -> Result<bool, Error> { - let pattern = if let Some(pattern) = &event_match.pattern { - pattern - } else if let Some(pattern_type) = &event_match.pattern_type { - let user_id = if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - }; - match &**pattern_type { - "user_id" => user_id, - "user_localpart" => utils::get_localpart_from_id(user_id)?, - _ => return Ok(false), - } - } else { - return Ok(false); - }; - - if event_match.key == "content.body" { - let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?; - Ok(compiled_pattern.is_match(&self.body)) - } else if let Some(value) = self.flattened_keys.get(&*event_match.key) { - let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?; - Ok(compiled_pattern.is_match(value)) - } else { - Ok(false) - } - } - - fn match_member_count(&self, is: &str) -> Result<bool, Error> { - let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?; - let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("=="); - let rhs: u64 = captures - .get(2) - .context("missing number")? - .as_str() - .parse()?; - - let matches = match ineq { - "" | "==" => self.room_member_count == rhs, - "<" => self.room_member_count < rhs, - ">" => self.room_member_count > rhs, - ">=" => self.room_member_count >= rhs, - "<=" => self.room_member_count <= rhs, - _ => false, - }; - - Ok(matches) - } -} - #[test] fn split_string() { let split_body: Vec<_> = WORD_BOUNDARY_EXPR @@ -604,21 +389,3 @@ fn test_deserialize_action() { let _: Action = serde_json::from_str(r#""coalesce""#).unwrap(); let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap(); } - -#[test] -fn push_rule_evaluator() { - let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - 0, - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); - - let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); - assert_eq!(result.len(), 3); -} diff --git a/stubs/synapse/synapse_rust.pyi b/stubs/synapse/synapse_rust/__init__.pyi index 5b51ba05d7..5b51ba05d7 100644 --- a/stubs/synapse/synapse_rust.pyi +++ b/stubs/synapse/synapse_rust/__init__.pyi diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi new file mode 100644 index 0000000000..3121bc4dde --- /dev/null +++ b/stubs/synapse/synapse_rust/push.pyi @@ -0,0 +1,47 @@ +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union + +from synapse.types import JsonDict + +class PushRule: + rule_id: str + priority_class: int + conditions: Sequence[Mapping[str, str]] + actions: Sequence[Union[Mapping[str, Any], str]] + default: bool + default_enabled: bool + + @staticmethod + def from_db( + rule_id: str, priority_class: int, conditions: str, actions: str + ) -> "PushRule": ... + +class PushRules: + def __init__(self, rules: Collection[PushRule]): ... + def rules(self) -> Collection[PushRule]: ... + +class FilteredPushRules: + def __init__( + self, + push_rules: PushRules, + enabled_map: Dict[str, bool], + msc3786_enabled: bool, + msc3772_enabled: bool, + ): ... + def rules(self) -> Collection[Tuple[PushRule, bool]]: ... + +class PushRuleEvaluator: + def __init__( + self, + flattened_keys: Mapping[str, str], + room_member_count: int, + sender_power_level: int, + power_levels: JsonDict, + relations: Mapping[str, Set[Tuple[str, str]]], + relation_match_enabled: bool, + ): ... + def run( + self, + push_rules: FilteredPushRules, + user_id: Optional[str], + display_name: Optional[str], + ) -> Collection[dict]: ... diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a8f5d6fc59..21e41c07cd 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -281,7 +281,7 @@ class BulkPushRuleEvaluator: ) = await self._get_power_levels_and_sender_level(event, context) relations = await self._get_mutual_relations( - event, itertools.chain(*rules_by_user.values()) + event, itertools.chain(*(r.rules() for r in rules_by_user.values())) ) logger.info("Flatten map: %s", _flatten_dict(event)) diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index a293d4dbb6..ebc13beda1 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -16,10 +16,9 @@ import copy from typing import Any, Dict, List, Optional from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP +from synapse.synapse_rust.push import FilteredPushRules, PushRule from synapse.types import UserID -from .baserules import FilteredPushRules, PushRule - def format_push_rules_for_user( user: UserID, ruleslist: FilteredPushRules diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 841b72e515..9d76e9b879 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -31,7 +31,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -853,7 +853,7 @@ class PushRuleStore(PushRulesWorkerStore): user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room - for rule, enabled in user_push_rules: + for rule, enabled in user_push_rules.rules(): if not enabled: continue diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 7b9b711521..bce65fab7d 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -15,11 +15,11 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes -from synapse.push.baserules import PushRule from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest import admin from synapse.rest.client import account, login from synapse.server import HomeServer +from synapse.synapse_rust.push import PushRule from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule made it - self.assertEqual( - push_rules, - [ - PushRule( - rule_id="personal.override.rule1", - priority_class=5, - conditions=[], - actions=[], - ) - ], - push_rules, - ) + self.assertEqual(len(push_rules), 1) + self.assertEqual(push_rules[0].rule_id, "personal.override.rule1") + self.assertEqual(push_rules[0].priority_class, 5) + self.assertEqual(push_rules[0].conditions, []) + self.assertEqual(push_rules[0].actions, []) # Request the deactivation of our account self._deactivate_my_account() @@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase): self._store.get_push_rules_for_user(self.user) ) # Filter out default rules; we don't care - push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)] + push_rules = [ + r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r) + ] # Check our rule no longer exists self.assertEqual(push_rules, [], push_rules) |