summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-09-07 16:54:19 +0100
committerErik Johnston <erik@matrix.org>2022-09-09 15:10:52 +0100
commit310df251730c895ceef6f030648232d517ab5a4a (patch)
tree1b170e11c44121bf7a37e566091a07b327d8ef17
parentExperimental rules (diff)
downloadsynapse-310df251730c895ceef6f030648232d517ab5a4a.tar.xz
Fixup
-rw-r--r--rust/src/push/evaluator.rs244
-rw-r--r--rust/src/push/mod.rs239
-rw-r--r--stubs/synapse/synapse_rust/__init__.pyi (renamed from stubs/synapse/synapse_rust.pyi)0
-rw-r--r--stubs/synapse/synapse_rust/push.pyi47
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/push/clientformat.py3
-rw-r--r--synapse/storage/databases/main/push_rule.py4
-rw-r--r--tests/handlers/test_deactivate_account.py27
8 files changed, 310 insertions, 256 deletions
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
new file mode 100644
index 0000000000..95d847cac0
--- /dev/null
+++ b/rust/src/push/evaluator.rs
@@ -0,0 +1,244 @@
+use std::collections::{BTreeMap, BTreeSet};
+
+use anyhow::{Context, Error};
+use log::warn;
+use pyo3::prelude::*;
+
+use super::{
+    utils::{get_localpart_from_id, glob_to_regex, GlobMatchType},
+    Action, Condition, EventMatchCondition, FilteredPushRules, INEQUALITY_EXPR,
+};
+
+#[pyclass]
+pub struct PushRuleEvaluator {
+    flattened_keys: BTreeMap<String, String>,
+    body: String,
+    room_member_count: u64,
+    power_levels: BTreeMap<String, BTreeMap<String, u64>>,
+    relations: BTreeMap<String, BTreeSet<(String, String)>>,
+    relation_match_enabled: bool,
+    sender_power_level: u64,
+}
+
+#[pymethods]
+impl PushRuleEvaluator {
+    #[new]
+    fn py_new(
+        flattened_keys: BTreeMap<String, String>,
+        room_member_count: u64,
+        sender_power_level: u64,
+        power_levels: BTreeMap<String, BTreeMap<String, u64>>,
+        relations: BTreeMap<String, BTreeSet<(String, String)>>,
+        relation_match_enabled: bool,
+    ) -> Result<Self, Error> {
+        let body = flattened_keys
+            .get("content.body")
+            .cloned()
+            .unwrap_or_default();
+
+        Ok(PushRuleEvaluator {
+            flattened_keys,
+            body,
+            room_member_count,
+            power_levels,
+            relations,
+            relation_match_enabled,
+            sender_power_level,
+        })
+    }
+
+    fn run(
+        &self,
+        push_rules: &FilteredPushRules,
+        user_id: Option<&str>,
+        display_name: Option<&str>,
+    ) -> Vec<Action> {
+        let mut actions = Vec::new();
+        'outer: for (push_rule, enabled) in push_rules.iter() {
+            if !enabled {
+                continue;
+            }
+
+            for condition in push_rule.conditions.iter() {
+                match self.match_condition(condition, user_id, display_name) {
+                    Ok(true) => {}
+                    Ok(false) => continue 'outer,
+                    Err(err) => {
+                        warn!("Condition match failed {err}");
+                        continue 'outer;
+                    }
+                }
+            }
+
+            actions.extend(
+                push_rule
+                    .actions
+                    .iter()
+                    // .filter(|a| **a != Action::DontNotify)
+                    .cloned(),
+            );
+
+            return actions;
+        }
+
+        actions
+    }
+}
+
+impl PushRuleEvaluator {
+    pub fn match_condition(
+        &self,
+        condition: &Condition,
+        user_id: Option<&str>,
+        display_name: Option<&str>,
+    ) -> Result<bool, Error> {
+        let result = match condition {
+            Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?,
+            Condition::ContainsDisplayName => {
+                if let Some(dn) = display_name {
+                    let matcher = glob_to_regex(dn, GlobMatchType::Word)?;
+                    matcher.is_match(&self.body)
+                } else {
+                    false
+                }
+            }
+            Condition::RoomMemberCount { is } => {
+                if let Some(is) = is {
+                    self.match_member_count(is)?
+                } else {
+                    false
+                }
+            }
+            Condition::SenderNotificationPermission { key } => {
+                let required_level = self
+                    .power_levels
+                    .get("notifications")
+                    .and_then(|m| m.get(key.as_ref()))
+                    .copied()
+                    .unwrap_or(50);
+
+                self.sender_power_level >= required_level
+            }
+            Condition::RelationMatch {
+                rel_type,
+                sender,
+                sender_type,
+            } => {
+                if !self.relation_match_enabled {
+                    return Ok(false);
+                }
+
+                let sender_pattern = if let Some(sender) = sender {
+                    sender
+                } else if let Some(sender_type) = sender_type {
+                    if sender_type == "user_id" {
+                        if let Some(user_id) = user_id {
+                            user_id
+                        } else {
+                            return Ok(false);
+                        }
+                    } else {
+                        warn!("Unrecognized sender_type:  {sender_type}");
+                        return Ok(false);
+                    }
+                } else {
+                    warn!("relation_match condition missing sender or sender_type");
+                    return Ok(false);
+                };
+
+                let relations = if let Some(relations) = self.relations.get(&**rel_type) {
+                    relations
+                } else {
+                    return Ok(false);
+                };
+
+                let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?;
+                let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?;
+
+                for (relation_sender, event_type) in relations {
+                    if sender_compiled_pattern.is_match(&relation_sender)
+                        && rel_type_compiled_pattern.is_match(event_type)
+                    {
+                        return Ok(true);
+                    }
+                }
+
+                false
+            }
+        };
+
+        Ok(result)
+    }
+
+    fn match_event_match(
+        &self,
+        event_match: &EventMatchCondition,
+        user_id: Option<&str>,
+    ) -> Result<bool, Error> {
+        let pattern = if let Some(pattern) = &event_match.pattern {
+            pattern
+        } else if let Some(pattern_type) = &event_match.pattern_type {
+            let user_id = if let Some(user_id) = user_id {
+                user_id
+            } else {
+                return Ok(false);
+            };
+            match &**pattern_type {
+                "user_id" => user_id,
+                "user_localpart" => get_localpart_from_id(user_id)?,
+                _ => return Ok(false),
+            }
+        } else {
+            return Ok(false);
+        };
+
+        if event_match.key == "content.body" {
+            let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?;
+            Ok(compiled_pattern.is_match(&self.body))
+        } else if let Some(value) = self.flattened_keys.get(&*event_match.key) {
+            let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?;
+            Ok(compiled_pattern.is_match(value))
+        } else {
+            Ok(false)
+        }
+    }
+
+    fn match_member_count(&self, is: &str) -> Result<bool, Error> {
+        let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?;
+        let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("==");
+        let rhs: u64 = captures
+            .get(2)
+            .context("missing number")?
+            .as_str()
+            .parse()?;
+
+        let matches = match ineq {
+            "" | "==" => self.room_member_count == rhs,
+            "<" => self.room_member_count < rhs,
+            ">" => self.room_member_count > rhs,
+            ">=" => self.room_member_count >= rhs,
+            "<=" => self.room_member_count <= rhs,
+            _ => false,
+        };
+
+        Ok(matches)
+    }
+}
+
+#[test]
+fn push_rule_evaluator() {
+    let mut flattened_keys = BTreeMap::new();
+    flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
+    let evaluator = PushRuleEvaluator::py_new(
+        flattened_keys,
+        10,
+        0,
+        BTreeMap::new(),
+        BTreeMap::new(),
+        true,
+    )
+    .unwrap();
+
+    let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
+    assert_eq!(result.len(), 3);
+}
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 6a099a15fa..33b2bc5d47 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -5,7 +5,7 @@
 //! allocation atm).
 
 use std::borrow::Cow;
-use std::collections::{BTreeMap, BTreeSet, HashMap};
+use std::collections::{BTreeMap, HashMap};
 
 use anyhow::{Context, Error};
 use lazy_static::lazy_static;
@@ -17,9 +17,10 @@ use serde::de::Error as _;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
-use self::utils::{glob_to_regex, GlobMatchType};
+use self::evaluator::PushRuleEvaluator;
 
 mod base_rules;
+mod evaluator;
 mod utils;
 
 lazy_static! {
@@ -349,222 +350,6 @@ impl FilteredPushRules {
     }
 }
 
-#[pyclass]
-pub struct PushRuleEvaluator {
-    flattened_keys: BTreeMap<String, String>,
-    body: String,
-    room_member_count: u64,
-    power_levels: BTreeMap<String, BTreeMap<String, u64>>,
-    relations: BTreeMap<String, BTreeSet<(String, String)>>,
-    relation_match_enabled: bool,
-    sender_power_level: u64,
-}
-
-#[pymethods]
-impl PushRuleEvaluator {
-    #[new]
-    fn py_new(
-        flattened_keys: BTreeMap<String, String>,
-        room_member_count: u64,
-        sender_power_level: u64,
-        power_levels: BTreeMap<String, BTreeMap<String, u64>>,
-        relations: BTreeMap<String, BTreeSet<(String, String)>>,
-        relation_match_enabled: bool,
-    ) -> Result<Self, Error> {
-        let body = flattened_keys
-            .get("content.body")
-            .cloned()
-            .unwrap_or_default();
-
-        Ok(PushRuleEvaluator {
-            flattened_keys,
-            body,
-            room_member_count,
-            power_levels,
-            relations,
-            relation_match_enabled,
-            sender_power_level,
-        })
-    }
-
-    fn run(
-        &self,
-        push_rules: &FilteredPushRules,
-        user_id: Option<&str>,
-        display_name: Option<&str>,
-    ) -> Vec<Action> {
-        let mut actions = Vec::new();
-        'outer: for (push_rule, enabled) in push_rules.iter() {
-            if !enabled {
-                continue;
-            }
-
-            for condition in push_rule.conditions.iter() {
-                match self.match_condition(condition, user_id, display_name) {
-                    Ok(true) => {}
-                    Ok(false) => continue 'outer,
-                    Err(err) => {
-                        warn!("Condition match failed {err}");
-                        continue 'outer;
-                    }
-                }
-            }
-
-            actions.extend(
-                push_rule
-                    .actions
-                    .iter()
-                    // .filter(|a| **a != Action::DontNotify)
-                    .cloned(),
-            );
-
-            return actions;
-        }
-
-        actions
-    }
-}
-
-impl PushRuleEvaluator {
-    pub fn match_condition(
-        &self,
-        condition: &Condition,
-        user_id: Option<&str>,
-        display_name: Option<&str>,
-    ) -> Result<bool, Error> {
-        let result = match condition {
-            Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?,
-            Condition::ContainsDisplayName => {
-                if let Some(dn) = display_name {
-                    let matcher = glob_to_regex(dn, GlobMatchType::Word)?;
-                    matcher.is_match(&self.body)
-                } else {
-                    false
-                }
-            }
-            Condition::RoomMemberCount { is } => {
-                if let Some(is) = is {
-                    self.match_member_count(is)?
-                } else {
-                    false
-                }
-            }
-            Condition::SenderNotificationPermission { key } => {
-                let required_level = self
-                    .power_levels
-                    .get("notifications")
-                    .and_then(|m| m.get(key.as_ref()))
-                    .copied()
-                    .unwrap_or(50);
-
-                self.sender_power_level >= required_level
-            }
-            Condition::RelationMatch {
-                rel_type,
-                sender,
-                sender_type,
-            } => {
-                if !self.relation_match_enabled {
-                    return Ok(false);
-                }
-
-                let sender_pattern = if let Some(sender) = sender {
-                    sender
-                } else if let Some(sender_type) = sender_type {
-                    if sender_type == "user_id" {
-                        if let Some(user_id) = user_id {
-                            user_id
-                        } else {
-                            return Ok(false);
-                        }
-                    } else {
-                        warn!("Unrecognized sender_type:  {sender_type}");
-                        return Ok(false);
-                    }
-                } else {
-                    warn!("relation_match condition missing sender or sender_type");
-                    return Ok(false);
-                };
-
-                let relations = if let Some(relations) = self.relations.get(&**rel_type) {
-                    relations
-                } else {
-                    return Ok(false);
-                };
-
-                let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?;
-                let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?;
-
-                for (relation_sender, event_type) in relations {
-                    if sender_compiled_pattern.is_match(&relation_sender)
-                        && rel_type_compiled_pattern.is_match(event_type)
-                    {
-                        return Ok(true);
-                    }
-                }
-
-                false
-            }
-        };
-
-        Ok(result)
-    }
-
-    fn match_event_match(
-        &self,
-        event_match: &EventMatchCondition,
-        user_id: Option<&str>,
-    ) -> Result<bool, Error> {
-        let pattern = if let Some(pattern) = &event_match.pattern {
-            pattern
-        } else if let Some(pattern_type) = &event_match.pattern_type {
-            let user_id = if let Some(user_id) = user_id {
-                user_id
-            } else {
-                return Ok(false);
-            };
-            match &**pattern_type {
-                "user_id" => user_id,
-                "user_localpart" => utils::get_localpart_from_id(user_id)?,
-                _ => return Ok(false),
-            }
-        } else {
-            return Ok(false);
-        };
-
-        if event_match.key == "content.body" {
-            let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?;
-            Ok(compiled_pattern.is_match(&self.body))
-        } else if let Some(value) = self.flattened_keys.get(&*event_match.key) {
-            let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?;
-            Ok(compiled_pattern.is_match(value))
-        } else {
-            Ok(false)
-        }
-    }
-
-    fn match_member_count(&self, is: &str) -> Result<bool, Error> {
-        let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?;
-        let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("==");
-        let rhs: u64 = captures
-            .get(2)
-            .context("missing number")?
-            .as_str()
-            .parse()?;
-
-        let matches = match ineq {
-            "" | "==" => self.room_member_count == rhs,
-            "<" => self.room_member_count < rhs,
-            ">" => self.room_member_count > rhs,
-            ">=" => self.room_member_count >= rhs,
-            "<=" => self.room_member_count <= rhs,
-            _ => false,
-        };
-
-        Ok(matches)
-    }
-}
-
 #[test]
 fn split_string() {
     let split_body: Vec<_> = WORD_BOUNDARY_EXPR
@@ -604,21 +389,3 @@ fn test_deserialize_action() {
     let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
     let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
 }
-
-#[test]
-fn push_rule_evaluator() {
-    let mut flattened_keys = BTreeMap::new();
-    flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
-    let evaluator = PushRuleEvaluator::py_new(
-        flattened_keys,
-        10,
-        0,
-        BTreeMap::new(),
-        BTreeMap::new(),
-        true,
-    )
-    .unwrap();
-
-    let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
-    assert_eq!(result.len(), 3);
-}
diff --git a/stubs/synapse/synapse_rust.pyi b/stubs/synapse/synapse_rust/__init__.pyi
index 5b51ba05d7..5b51ba05d7 100644
--- a/stubs/synapse/synapse_rust.pyi
+++ b/stubs/synapse/synapse_rust/__init__.pyi
diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi
new file mode 100644
index 0000000000..3121bc4dde
--- /dev/null
+++ b/stubs/synapse/synapse_rust/push.pyi
@@ -0,0 +1,47 @@
+from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
+
+from synapse.types import JsonDict
+
+class PushRule:
+    rule_id: str
+    priority_class: int
+    conditions: Sequence[Mapping[str, str]]
+    actions: Sequence[Union[Mapping[str, Any], str]]
+    default: bool
+    default_enabled: bool
+
+    @staticmethod
+    def from_db(
+        rule_id: str, priority_class: int, conditions: str, actions: str
+    ) -> "PushRule": ...
+
+class PushRules:
+    def __init__(self, rules: Collection[PushRule]): ...
+    def rules(self) -> Collection[PushRule]: ...
+
+class FilteredPushRules:
+    def __init__(
+        self,
+        push_rules: PushRules,
+        enabled_map: Dict[str, bool],
+        msc3786_enabled: bool,
+        msc3772_enabled: bool,
+    ): ...
+    def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
+
+class PushRuleEvaluator:
+    def __init__(
+        self,
+        flattened_keys: Mapping[str, str],
+        room_member_count: int,
+        sender_power_level: int,
+        power_levels: JsonDict,
+        relations: Mapping[str, Set[Tuple[str, str]]],
+        relation_match_enabled: bool,
+    ): ...
+    def run(
+        self,
+        push_rules: FilteredPushRules,
+        user_id: Optional[str],
+        display_name: Optional[str],
+    ) -> Collection[dict]: ...
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index a8f5d6fc59..21e41c07cd 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -281,7 +281,7 @@ class BulkPushRuleEvaluator:
         ) = await self._get_power_levels_and_sender_level(event, context)
 
         relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
+            event, itertools.chain(*(r.rules() for r in rules_by_user.values()))
         )
 
         logger.info("Flatten map: %s", _flatten_dict(event))
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a293d4dbb6..ebc13beda1 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -16,10 +16,9 @@ import copy
 from typing import Any, Dict, List, Optional
 
 from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.synapse_rust.push import FilteredPushRules, PushRule
 from synapse.types import UserID
 
-from .baserules import FilteredPushRules, PushRule
-
 
 def format_push_rules_for_user(
     user: UserID, ruleslist: FilteredPushRules
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 841b72e515..9d76e9b879 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -31,7 +31,7 @@ from typing import (
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -853,7 +853,7 @@ class PushRuleStore(PushRulesWorkerStore):
         user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
-        for rule, enabled in user_push_rules:
+        for rule, enabled in user_push_rules.rules():
             if not enabled:
                 continue
 
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 7b9b711521..bce65fab7d 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -15,11 +15,11 @@
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import AccountDataTypes
-from synapse.push.baserules import PushRule
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP
 from synapse.rest import admin
 from synapse.rest.client import account, login
 from synapse.server import HomeServer
+from synapse.synapse_rust.push import PushRule
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase):
             self._store.get_push_rules_for_user(self.user)
         )
         # Filter out default rules; we don't care
-        push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
+        push_rules = [
+            r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+        ]
         # Check our rule made it
-        self.assertEqual(
-            push_rules,
-            [
-                PushRule(
-                    rule_id="personal.override.rule1",
-                    priority_class=5,
-                    conditions=[],
-                    actions=[],
-                )
-            ],
-            push_rules,
-        )
+        self.assertEqual(len(push_rules), 1)
+        self.assertEqual(push_rules[0].rule_id, "personal.override.rule1")
+        self.assertEqual(push_rules[0].priority_class, 5)
+        self.assertEqual(push_rules[0].conditions, [])
+        self.assertEqual(push_rules[0].actions, [])
 
         # Request the deactivation of our account
         self._deactivate_my_account()
@@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase):
             self._store.get_push_rules_for_user(self.user)
         )
         # Filter out default rules; we don't care
-        push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
+        push_rules = [
+            r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+        ]
         # Check our rule no longer exists
         self.assertEqual(push_rules, [], push_rules)