summary refs log tree commit diff
path: root/rust/src
diff options
context:
space:
mode:
Diffstat (limited to 'rust/src')
-rw-r--r--rust/src/push/evaluator.rs244
-rw-r--r--rust/src/push/mod.rs239
2 files changed, 247 insertions, 236 deletions
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
new file mode 100644

index 0000000000..95d847cac0 --- /dev/null +++ b/rust/src/push/evaluator.rs
@@ -0,0 +1,244 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use anyhow::{Context, Error}; +use log::warn; +use pyo3::prelude::*; + +use super::{ + utils::{get_localpart_from_id, glob_to_regex, GlobMatchType}, + Action, Condition, EventMatchCondition, FilteredPushRules, INEQUALITY_EXPR, +}; + +#[pyclass] +pub struct PushRuleEvaluator { + flattened_keys: BTreeMap<String, String>, + body: String, + room_member_count: u64, + power_levels: BTreeMap<String, BTreeMap<String, u64>>, + relations: BTreeMap<String, BTreeSet<(String, String)>>, + relation_match_enabled: bool, + sender_power_level: u64, +} + +#[pymethods] +impl PushRuleEvaluator { + #[new] + fn py_new( + flattened_keys: BTreeMap<String, String>, + room_member_count: u64, + sender_power_level: u64, + power_levels: BTreeMap<String, BTreeMap<String, u64>>, + relations: BTreeMap<String, BTreeSet<(String, String)>>, + relation_match_enabled: bool, + ) -> Result<Self, Error> { + let body = flattened_keys + .get("content.body") + .cloned() + .unwrap_or_default(); + + Ok(PushRuleEvaluator { + flattened_keys, + body, + room_member_count, + power_levels, + relations, + relation_match_enabled, + sender_power_level, + }) + } + + fn run( + &self, + push_rules: &FilteredPushRules, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Vec<Action> { + let mut actions = Vec::new(); + 'outer: for (push_rule, enabled) in push_rules.iter() { + if !enabled { + continue; + } + + for condition in push_rule.conditions.iter() { + match self.match_condition(condition, user_id, display_name) { + Ok(true) => {} + Ok(false) => continue 'outer, + Err(err) => { + warn!("Condition match failed {err}"); + continue 'outer; + } + } + } + + actions.extend( + push_rule + .actions + .iter() + // .filter(|a| **a != Action::DontNotify) + .cloned(), + ); + + return actions; + } + + actions + } +} + +impl PushRuleEvaluator { + pub fn match_condition( + &self, + condition: &Condition, + user_id: Option<&str>, + display_name: Option<&str>, + ) -> Result<bool, Error> { + let result = match condition { + Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?, + Condition::ContainsDisplayName => { + if let Some(dn) = display_name { + let matcher = glob_to_regex(dn, GlobMatchType::Word)?; + matcher.is_match(&self.body) + } else { + false + } + } + Condition::RoomMemberCount { is } => { + if let Some(is) = is { + self.match_member_count(is)? + } else { + false + } + } + Condition::SenderNotificationPermission { key } => { + let required_level = self + .power_levels + .get("notifications") + .and_then(|m| m.get(key.as_ref())) + .copied() + .unwrap_or(50); + + self.sender_power_level >= required_level + } + Condition::RelationMatch { + rel_type, + sender, + sender_type, + } => { + if !self.relation_match_enabled { + return Ok(false); + } + + let sender_pattern = if let Some(sender) = sender { + sender + } else if let Some(sender_type) = sender_type { + if sender_type == "user_id" { + if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + } + } else { + warn!("Unrecognized sender_type: {sender_type}"); + return Ok(false); + } + } else { + warn!("relation_match condition missing sender or sender_type"); + return Ok(false); + }; + + let relations = if let Some(relations) = self.relations.get(&**rel_type) { + relations + } else { + return Ok(false); + }; + + let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?; + let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?; + + for (relation_sender, event_type) in relations { + if sender_compiled_pattern.is_match(&relation_sender) + && rel_type_compiled_pattern.is_match(event_type) + { + return Ok(true); + } + } + + false + } + }; + + Ok(result) + } + + fn match_event_match( + &self, + event_match: &EventMatchCondition, + user_id: Option<&str>, + ) -> Result<bool, Error> { + let pattern = if let Some(pattern) = &event_match.pattern { + pattern + } else if let Some(pattern_type) = &event_match.pattern_type { + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + match &**pattern_type { + "user_id" => user_id, + "user_localpart" => get_localpart_from_id(user_id)?, + _ => return Ok(false), + } + } else { + return Ok(false); + }; + + if event_match.key == "content.body" { + let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?; + Ok(compiled_pattern.is_match(&self.body)) + } else if let Some(value) = self.flattened_keys.get(&*event_match.key) { + let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?; + Ok(compiled_pattern.is_match(value)) + } else { + Ok(false) + } + } + + fn match_member_count(&self, is: &str) -> Result<bool, Error> { + let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?; + let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("=="); + let rhs: u64 = captures + .get(2) + .context("missing number")? + .as_str() + .parse()?; + + let matches = match ineq { + "" | "==" => self.room_member_count == rhs, + "<" => self.room_member_count < rhs, + ">" => self.room_member_count > rhs, + ">=" => self.room_member_count >= rhs, + "<=" => self.room_member_count <= rhs, + _ => false, + }; + + Ok(matches) + } +} + +#[test] +fn push_rule_evaluator() { + let mut flattened_keys = BTreeMap::new(); + flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + let evaluator = PushRuleEvaluator::py_new( + flattened_keys, + 10, + 0, + BTreeMap::new(), + BTreeMap::new(), + true, + ) + .unwrap(); + + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + assert_eq!(result.len(), 3); +} diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 6a099a15fa..33b2bc5d47 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs
@@ -5,7 +5,7 @@ //! allocation atm). use std::borrow::Cow; -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{BTreeMap, HashMap}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -17,9 +17,10 @@ use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; -use self::utils::{glob_to_regex, GlobMatchType}; +use self::evaluator::PushRuleEvaluator; mod base_rules; +mod evaluator; mod utils; lazy_static! { @@ -349,222 +350,6 @@ impl FilteredPushRules { } } -#[pyclass] -pub struct PushRuleEvaluator { - flattened_keys: BTreeMap<String, String>, - body: String, - room_member_count: u64, - power_levels: BTreeMap<String, BTreeMap<String, u64>>, - relations: BTreeMap<String, BTreeSet<(String, String)>>, - relation_match_enabled: bool, - sender_power_level: u64, -} - -#[pymethods] -impl PushRuleEvaluator { - #[new] - fn py_new( - flattened_keys: BTreeMap<String, String>, - room_member_count: u64, - sender_power_level: u64, - power_levels: BTreeMap<String, BTreeMap<String, u64>>, - relations: BTreeMap<String, BTreeSet<(String, String)>>, - relation_match_enabled: bool, - ) -> Result<Self, Error> { - let body = flattened_keys - .get("content.body") - .cloned() - .unwrap_or_default(); - - Ok(PushRuleEvaluator { - flattened_keys, - body, - room_member_count, - power_levels, - relations, - relation_match_enabled, - sender_power_level, - }) - } - - fn run( - &self, - push_rules: &FilteredPushRules, - user_id: Option<&str>, - display_name: Option<&str>, - ) -> Vec<Action> { - let mut actions = Vec::new(); - 'outer: for (push_rule, enabled) in push_rules.iter() { - if !enabled { - continue; - } - - for condition in push_rule.conditions.iter() { - match self.match_condition(condition, user_id, display_name) { - Ok(true) => {} - Ok(false) => continue 'outer, - Err(err) => { - warn!("Condition match failed {err}"); - continue 'outer; - } - } - } - - actions.extend( - push_rule - .actions - .iter() - // .filter(|a| **a != Action::DontNotify) - .cloned(), - ); - - return actions; - } - - actions - } -} - -impl PushRuleEvaluator { - pub fn match_condition( - &self, - condition: &Condition, - user_id: Option<&str>, - display_name: Option<&str>, - ) -> Result<bool, Error> { - let result = match condition { - Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?, - Condition::ContainsDisplayName => { - if let Some(dn) = display_name { - let matcher = glob_to_regex(dn, GlobMatchType::Word)?; - matcher.is_match(&self.body) - } else { - false - } - } - Condition::RoomMemberCount { is } => { - if let Some(is) = is { - self.match_member_count(is)? - } else { - false - } - } - Condition::SenderNotificationPermission { key } => { - let required_level = self - .power_levels - .get("notifications") - .and_then(|m| m.get(key.as_ref())) - .copied() - .unwrap_or(50); - - self.sender_power_level >= required_level - } - Condition::RelationMatch { - rel_type, - sender, - sender_type, - } => { - if !self.relation_match_enabled { - return Ok(false); - } - - let sender_pattern = if let Some(sender) = sender { - sender - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - warn!("relation_match condition missing sender or sender_type"); - return Ok(false); - }; - - let relations = if let Some(relations) = self.relations.get(&**rel_type) { - relations - } else { - return Ok(false); - }; - - let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?; - let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?; - - for (relation_sender, event_type) in relations { - if sender_compiled_pattern.is_match(&relation_sender) - && rel_type_compiled_pattern.is_match(event_type) - { - return Ok(true); - } - } - - false - } - }; - - Ok(result) - } - - fn match_event_match( - &self, - event_match: &EventMatchCondition, - user_id: Option<&str>, - ) -> Result<bool, Error> { - let pattern = if let Some(pattern) = &event_match.pattern { - pattern - } else if let Some(pattern_type) = &event_match.pattern_type { - let user_id = if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - }; - match &**pattern_type { - "user_id" => user_id, - "user_localpart" => utils::get_localpart_from_id(user_id)?, - _ => return Ok(false), - } - } else { - return Ok(false); - }; - - if event_match.key == "content.body" { - let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?; - Ok(compiled_pattern.is_match(&self.body)) - } else if let Some(value) = self.flattened_keys.get(&*event_match.key) { - let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?; - Ok(compiled_pattern.is_match(value)) - } else { - Ok(false) - } - } - - fn match_member_count(&self, is: &str) -> Result<bool, Error> { - let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?; - let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("=="); - let rhs: u64 = captures - .get(2) - .context("missing number")? - .as_str() - .parse()?; - - let matches = match ineq { - "" | "==" => self.room_member_count == rhs, - "<" => self.room_member_count < rhs, - ">" => self.room_member_count > rhs, - ">=" => self.room_member_count >= rhs, - "<=" => self.room_member_count <= rhs, - _ => false, - }; - - Ok(matches) - } -} - #[test] fn split_string() { let split_body: Vec<_> = WORD_BOUNDARY_EXPR @@ -604,21 +389,3 @@ fn test_deserialize_action() { let _: Action = serde_json::from_str(r#""coalesce""#).unwrap(); let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap(); } - -#[test] -fn push_rule_evaluator() { - let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - 0, - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); - - let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); - assert_eq!(result.len(), 3); -}