summary refs log tree commit diff
path: root/rust/src/push/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/src/push/mod.rs')
-rw-r--r--rust/src/push/mod.rs28
1 files changed, 20 insertions, 8 deletions
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index de6764e7c5..30fffc31ad 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -42,7 +42,6 @@
 //!
 //! The set of "base rules" are the list of rules that every user has by default. A
 //! user can modify their copy of the push rules in one of three ways:
-//!
 //!     1. Adding a new push rule of a certain kind
 //!     2. Changing the actions of a base rule
 //!     3. Enabling/disabling a base rule.
@@ -58,12 +57,16 @@ use std::collections::{BTreeMap, HashMap, HashSet};
 use anyhow::{Context, Error};
 use log::warn;
 use pyo3::prelude::*;
-use pythonize::pythonize;
+use pythonize::{depythonize, pythonize};
 use serde::de::Error as _;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
+use self::evaluator::PushRuleEvaluator;
+
 mod base_rules;
+pub mod evaluator;
+pub mod utils;
 
 /// Called when registering modules with python.
 pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -71,6 +74,7 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     child_module.add_class::<PushRule>()?;
     child_module.add_class::<PushRules>()?;
     child_module.add_class::<FilteredPushRules>()?;
+    child_module.add_class::<PushRuleEvaluator>()?;
     child_module.add_function(wrap_pyfunction!(get_base_rule_ids, m)?)?;
 
     m.add_submodule(child_module)?;
@@ -274,6 +278,8 @@ pub enum KnownCondition {
     #[serde(rename = "org.matrix.msc3772.relation_match")]
     RelationMatch {
         rel_type: Cow<'static, str>,
+        #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
+        event_type_pattern: Option<Cow<'static, str>>,
         #[serde(skip_serializing_if = "Option::is_none")]
         sender: Option<Cow<'static, str>>,
         #[serde(skip_serializing_if = "Option::is_none")]
@@ -287,20 +293,26 @@ impl IntoPy<PyObject> for Condition {
     }
 }
 
+impl<'source> FromPyObject<'source> for Condition {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        Ok(depythonize(ob)?)
+    }
+}
+
 /// The body of a [`Condition::EventMatch`]
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct EventMatchCondition {
-    key: Cow<'static, str>,
+    pub key: Cow<'static, str>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pattern: Option<Cow<'static, str>>,
+    pub pattern: Option<Cow<'static, str>>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pattern_type: Option<Cow<'static, str>>,
+    pub pattern_type: Option<Cow<'static, str>>,
 }
 
 /// The collection of push rules for a user.
 #[derive(Debug, Clone, Default)]
 #[pyclass(frozen)]
-struct PushRules {
+pub struct PushRules {
     /// Custom push rules that override a base rule.
     overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
 
@@ -319,7 +331,7 @@ struct PushRules {
 #[pymethods]
 impl PushRules {
     #[new]
-    fn new(rules: Vec<PushRule>) -> PushRules {
+    pub fn new(rules: Vec<PushRule>) -> PushRules {
         let mut push_rules: PushRules = Default::default();
 
         for rule in rules {
@@ -396,7 +408,7 @@ pub struct FilteredPushRules {
 #[pymethods]
 impl FilteredPushRules {
     #[new]
-    fn py_new(
+    pub fn py_new(
         push_rules: PushRules,
         enabled_map: BTreeMap<String, bool>,
         msc3786_enabled: bool,