diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
new file mode 100644
index 0000000000..95d847cac0
--- /dev/null
+++ b/rust/src/push/evaluator.rs
@@ -0,0 +1,244 @@
+use std::collections::{BTreeMap, BTreeSet};
+
+use anyhow::{Context, Error};
+use log::warn;
+use pyo3::prelude::*;
+
+use super::{
+ utils::{get_localpart_from_id, glob_to_regex, GlobMatchType},
+ Action, Condition, EventMatchCondition, FilteredPushRules, INEQUALITY_EXPR,
+};
+
+#[pyclass]
+pub struct PushRuleEvaluator {
+ flattened_keys: BTreeMap<String, String>,
+ body: String,
+ room_member_count: u64,
+ power_levels: BTreeMap<String, BTreeMap<String, u64>>,
+ relations: BTreeMap<String, BTreeSet<(String, String)>>,
+ relation_match_enabled: bool,
+ sender_power_level: u64,
+}
+
+#[pymethods]
+impl PushRuleEvaluator {
+ #[new]
+ fn py_new(
+ flattened_keys: BTreeMap<String, String>,
+ room_member_count: u64,
+ sender_power_level: u64,
+ power_levels: BTreeMap<String, BTreeMap<String, u64>>,
+ relations: BTreeMap<String, BTreeSet<(String, String)>>,
+ relation_match_enabled: bool,
+ ) -> Result<Self, Error> {
+ let body = flattened_keys
+ .get("content.body")
+ .cloned()
+ .unwrap_or_default();
+
+ Ok(PushRuleEvaluator {
+ flattened_keys,
+ body,
+ room_member_count,
+ power_levels,
+ relations,
+ relation_match_enabled,
+ sender_power_level,
+ })
+ }
+
+ fn run(
+ &self,
+ push_rules: &FilteredPushRules,
+ user_id: Option<&str>,
+ display_name: Option<&str>,
+ ) -> Vec<Action> {
+ let mut actions = Vec::new();
+ 'outer: for (push_rule, enabled) in push_rules.iter() {
+ if !enabled {
+ continue;
+ }
+
+ for condition in push_rule.conditions.iter() {
+ match self.match_condition(condition, user_id, display_name) {
+ Ok(true) => {}
+ Ok(false) => continue 'outer,
+ Err(err) => {
+ warn!("Condition match failed {err}");
+ continue 'outer;
+ }
+ }
+ }
+
+ actions.extend(
+ push_rule
+ .actions
+ .iter()
+ // .filter(|a| **a != Action::DontNotify)
+ .cloned(),
+ );
+
+ return actions;
+ }
+
+ actions
+ }
+}
+
+impl PushRuleEvaluator {
+ pub fn match_condition(
+ &self,
+ condition: &Condition,
+ user_id: Option<&str>,
+ display_name: Option<&str>,
+ ) -> Result<bool, Error> {
+ let result = match condition {
+ Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?,
+ Condition::ContainsDisplayName => {
+ if let Some(dn) = display_name {
+ let matcher = glob_to_regex(dn, GlobMatchType::Word)?;
+ matcher.is_match(&self.body)
+ } else {
+ false
+ }
+ }
+ Condition::RoomMemberCount { is } => {
+ if let Some(is) = is {
+ self.match_member_count(is)?
+ } else {
+ false
+ }
+ }
+ Condition::SenderNotificationPermission { key } => {
+ let required_level = self
+ .power_levels
+ .get("notifications")
+ .and_then(|m| m.get(key.as_ref()))
+ .copied()
+ .unwrap_or(50);
+
+ self.sender_power_level >= required_level
+ }
+ Condition::RelationMatch {
+ rel_type,
+ sender,
+ sender_type,
+ } => {
+ if !self.relation_match_enabled {
+ return Ok(false);
+ }
+
+ let sender_pattern = if let Some(sender) = sender {
+ sender
+ } else if let Some(sender_type) = sender_type {
+ if sender_type == "user_id" {
+ if let Some(user_id) = user_id {
+ user_id
+ } else {
+ return Ok(false);
+ }
+ } else {
+ warn!("Unrecognized sender_type: {sender_type}");
+ return Ok(false);
+ }
+ } else {
+ warn!("relation_match condition missing sender or sender_type");
+ return Ok(false);
+ };
+
+ let relations = if let Some(relations) = self.relations.get(&**rel_type) {
+ relations
+ } else {
+ return Ok(false);
+ };
+
+ let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?;
+ let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?;
+
+ for (relation_sender, event_type) in relations {
+ if sender_compiled_pattern.is_match(&relation_sender)
+ && rel_type_compiled_pattern.is_match(event_type)
+ {
+ return Ok(true);
+ }
+ }
+
+ false
+ }
+ };
+
+ Ok(result)
+ }
+
+ fn match_event_match(
+ &self,
+ event_match: &EventMatchCondition,
+ user_id: Option<&str>,
+ ) -> Result<bool, Error> {
+ let pattern = if let Some(pattern) = &event_match.pattern {
+ pattern
+ } else if let Some(pattern_type) = &event_match.pattern_type {
+ let user_id = if let Some(user_id) = user_id {
+ user_id
+ } else {
+ return Ok(false);
+ };
+ match &**pattern_type {
+ "user_id" => user_id,
+ "user_localpart" => get_localpart_from_id(user_id)?,
+ _ => return Ok(false),
+ }
+ } else {
+ return Ok(false);
+ };
+
+ if event_match.key == "content.body" {
+ let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?;
+ Ok(compiled_pattern.is_match(&self.body))
+ } else if let Some(value) = self.flattened_keys.get(&*event_match.key) {
+ let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?;
+ Ok(compiled_pattern.is_match(value))
+ } else {
+ Ok(false)
+ }
+ }
+
+ fn match_member_count(&self, is: &str) -> Result<bool, Error> {
+ let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?;
+ let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("==");
+ let rhs: u64 = captures
+ .get(2)
+ .context("missing number")?
+ .as_str()
+ .parse()?;
+
+ let matches = match ineq {
+ "" | "==" => self.room_member_count == rhs,
+ "<" => self.room_member_count < rhs,
+ ">" => self.room_member_count > rhs,
+ ">=" => self.room_member_count >= rhs,
+ "<=" => self.room_member_count <= rhs,
+ _ => false,
+ };
+
+ Ok(matches)
+ }
+}
+
+#[test]
+fn push_rule_evaluator() {
+ let mut flattened_keys = BTreeMap::new();
+ flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
+ let evaluator = PushRuleEvaluator::py_new(
+ flattened_keys,
+ 10,
+ 0,
+ BTreeMap::new(),
+ BTreeMap::new(),
+ true,
+ )
+ .unwrap();
+
+ let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
+ assert_eq!(result.len(), 3);
+}
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 6a099a15fa..33b2bc5d47 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -5,7 +5,7 @@
//! allocation atm).
use std::borrow::Cow;
-use std::collections::{BTreeMap, BTreeSet, HashMap};
+use std::collections::{BTreeMap, HashMap};
use anyhow::{Context, Error};
use lazy_static::lazy_static;
@@ -17,9 +17,10 @@ use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
-use self::utils::{glob_to_regex, GlobMatchType};
+use self::evaluator::PushRuleEvaluator;
mod base_rules;
+mod evaluator;
mod utils;
lazy_static! {
@@ -349,222 +350,6 @@ impl FilteredPushRules {
}
}
-#[pyclass]
-pub struct PushRuleEvaluator {
- flattened_keys: BTreeMap<String, String>,
- body: String,
- room_member_count: u64,
- power_levels: BTreeMap<String, BTreeMap<String, u64>>,
- relations: BTreeMap<String, BTreeSet<(String, String)>>,
- relation_match_enabled: bool,
- sender_power_level: u64,
-}
-
-#[pymethods]
-impl PushRuleEvaluator {
- #[new]
- fn py_new(
- flattened_keys: BTreeMap<String, String>,
- room_member_count: u64,
- sender_power_level: u64,
- power_levels: BTreeMap<String, BTreeMap<String, u64>>,
- relations: BTreeMap<String, BTreeSet<(String, String)>>,
- relation_match_enabled: bool,
- ) -> Result<Self, Error> {
- let body = flattened_keys
- .get("content.body")
- .cloned()
- .unwrap_or_default();
-
- Ok(PushRuleEvaluator {
- flattened_keys,
- body,
- room_member_count,
- power_levels,
- relations,
- relation_match_enabled,
- sender_power_level,
- })
- }
-
- fn run(
- &self,
- push_rules: &FilteredPushRules,
- user_id: Option<&str>,
- display_name: Option<&str>,
- ) -> Vec<Action> {
- let mut actions = Vec::new();
- 'outer: for (push_rule, enabled) in push_rules.iter() {
- if !enabled {
- continue;
- }
-
- for condition in push_rule.conditions.iter() {
- match self.match_condition(condition, user_id, display_name) {
- Ok(true) => {}
- Ok(false) => continue 'outer,
- Err(err) => {
- warn!("Condition match failed {err}");
- continue 'outer;
- }
- }
- }
-
- actions.extend(
- push_rule
- .actions
- .iter()
- // .filter(|a| **a != Action::DontNotify)
- .cloned(),
- );
-
- return actions;
- }
-
- actions
- }
-}
-
-impl PushRuleEvaluator {
- pub fn match_condition(
- &self,
- condition: &Condition,
- user_id: Option<&str>,
- display_name: Option<&str>,
- ) -> Result<bool, Error> {
- let result = match condition {
- Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?,
- Condition::ContainsDisplayName => {
- if let Some(dn) = display_name {
- let matcher = glob_to_regex(dn, GlobMatchType::Word)?;
- matcher.is_match(&self.body)
- } else {
- false
- }
- }
- Condition::RoomMemberCount { is } => {
- if let Some(is) = is {
- self.match_member_count(is)?
- } else {
- false
- }
- }
- Condition::SenderNotificationPermission { key } => {
- let required_level = self
- .power_levels
- .get("notifications")
- .and_then(|m| m.get(key.as_ref()))
- .copied()
- .unwrap_or(50);
-
- self.sender_power_level >= required_level
- }
- Condition::RelationMatch {
- rel_type,
- sender,
- sender_type,
- } => {
- if !self.relation_match_enabled {
- return Ok(false);
- }
-
- let sender_pattern = if let Some(sender) = sender {
- sender
- } else if let Some(sender_type) = sender_type {
- if sender_type == "user_id" {
- if let Some(user_id) = user_id {
- user_id
- } else {
- return Ok(false);
- }
- } else {
- warn!("Unrecognized sender_type: {sender_type}");
- return Ok(false);
- }
- } else {
- warn!("relation_match condition missing sender or sender_type");
- return Ok(false);
- };
-
- let relations = if let Some(relations) = self.relations.get(&**rel_type) {
- relations
- } else {
- return Ok(false);
- };
-
- let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?;
- let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?;
-
- for (relation_sender, event_type) in relations {
- if sender_compiled_pattern.is_match(&relation_sender)
- && rel_type_compiled_pattern.is_match(event_type)
- {
- return Ok(true);
- }
- }
-
- false
- }
- };
-
- Ok(result)
- }
-
- fn match_event_match(
- &self,
- event_match: &EventMatchCondition,
- user_id: Option<&str>,
- ) -> Result<bool, Error> {
- let pattern = if let Some(pattern) = &event_match.pattern {
- pattern
- } else if let Some(pattern_type) = &event_match.pattern_type {
- let user_id = if let Some(user_id) = user_id {
- user_id
- } else {
- return Ok(false);
- };
- match &**pattern_type {
- "user_id" => user_id,
- "user_localpart" => utils::get_localpart_from_id(user_id)?,
- _ => return Ok(false),
- }
- } else {
- return Ok(false);
- };
-
- if event_match.key == "content.body" {
- let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?;
- Ok(compiled_pattern.is_match(&self.body))
- } else if let Some(value) = self.flattened_keys.get(&*event_match.key) {
- let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?;
- Ok(compiled_pattern.is_match(value))
- } else {
- Ok(false)
- }
- }
-
- fn match_member_count(&self, is: &str) -> Result<bool, Error> {
- let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?;
- let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("==");
- let rhs: u64 = captures
- .get(2)
- .context("missing number")?
- .as_str()
- .parse()?;
-
- let matches = match ineq {
- "" | "==" => self.room_member_count == rhs,
- "<" => self.room_member_count < rhs,
- ">" => self.room_member_count > rhs,
- ">=" => self.room_member_count >= rhs,
- "<=" => self.room_member_count <= rhs,
- _ => false,
- };
-
- Ok(matches)
- }
-}
-
#[test]
fn split_string() {
let split_body: Vec<_> = WORD_BOUNDARY_EXPR
@@ -604,21 +389,3 @@ fn test_deserialize_action() {
let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
}
-
-#[test]
-fn push_rule_evaluator() {
- let mut flattened_keys = BTreeMap::new();
- flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
- let evaluator = PushRuleEvaluator::py_new(
- flattened_keys,
- 10,
- 0,
- BTreeMap::new(),
- BTreeMap::new(),
- true,
- )
- .unwrap();
-
- let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
- assert_eq!(result.len(), 3);
-}
|