summary refs log tree commit diff
path: root/rust
diff options
context:
space:
mode:
Diffstat (limited to 'rust')
-rw-r--r--rust/benches/evaluator.rs32
-rw-r--r--rust/src/push/evaluator.rs65
-rw-r--r--rust/src/push/mod.rs33
3 files changed, 101 insertions, 29 deletions
diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs
index 229553ebf8..8213dfd9ea 100644
--- a/rust/benches/evaluator.rs
+++ b/rust/benches/evaluator.rs
@@ -15,8 +15,8 @@
 #![feature(test)]
 use std::collections::BTreeSet;
 use synapse::push::{
-    evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
-    SimpleJsonValue,
+    evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue,
+    PushRules, SimpleJsonValue,
 };
 use test::Bencher;
 
@@ -27,15 +27,15 @@ fn bench_match_exact(b: &mut Bencher) {
     let flattened_keys = [
         (
             "type".to_string(),
-            SimpleJsonValue::Str("m.text".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
         ),
         (
             "room_id".to_string(),
-            SimpleJsonValue::Str("!room:server".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
         ),
         (
             "content.body".to_string(),
-            SimpleJsonValue::Str("test message".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
         ),
     ]
     .into_iter()
@@ -54,6 +54,7 @@ fn bench_match_exact(b: &mut Bencher) {
         vec![],
         false,
         false,
+        false,
     )
     .unwrap();
 
@@ -76,15 +77,15 @@ fn bench_match_word(b: &mut Bencher) {
     let flattened_keys = [
         (
             "type".to_string(),
-            SimpleJsonValue::Str("m.text".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
         ),
         (
             "room_id".to_string(),
-            SimpleJsonValue::Str("!room:server".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
         ),
         (
             "content.body".to_string(),
-            SimpleJsonValue::Str("test message".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
         ),
     ]
     .into_iter()
@@ -103,6 +104,7 @@ fn bench_match_word(b: &mut Bencher) {
         vec![],
         false,
         false,
+        false,
     )
     .unwrap();
 
@@ -125,15 +127,15 @@ fn bench_match_word_miss(b: &mut Bencher) {
     let flattened_keys = [
         (
             "type".to_string(),
-            SimpleJsonValue::Str("m.text".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
         ),
         (
             "room_id".to_string(),
-            SimpleJsonValue::Str("!room:server".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
         ),
         (
             "content.body".to_string(),
-            SimpleJsonValue::Str("test message".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
         ),
     ]
     .into_iter()
@@ -152,6 +154,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
         vec![],
         false,
         false,
+        false,
     )
     .unwrap();
 
@@ -174,15 +177,15 @@ fn bench_eval_message(b: &mut Bencher) {
     let flattened_keys = [
         (
             "type".to_string(),
-            SimpleJsonValue::Str("m.text".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
         ),
         (
             "room_id".to_string(),
-            SimpleJsonValue::Str("!room:server".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
         ),
         (
             "content.body".to_string(),
-            SimpleJsonValue::Str("test message".to_string()),
+            JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
         ),
     ]
     .into_iter()
@@ -201,6 +204,7 @@ fn bench_eval_message(b: &mut Bencher) {
         vec![],
         false,
         false,
+        false,
     )
     .unwrap();
 
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
index dd6b4343ec..2eaa06ad76 100644
--- a/rust/src/push/evaluator.rs
+++ b/rust/src/push/evaluator.rs
@@ -14,6 +14,7 @@
 
 use std::collections::{BTreeMap, BTreeSet};
 
+use crate::push::JsonValue;
 use anyhow::{Context, Error};
 use lazy_static::lazy_static;
 use log::warn;
@@ -63,7 +64,7 @@ impl RoomVersionFeatures {
 pub struct PushRuleEvaluator {
     /// A mapping of "flattened" keys to simple JSON values in the event, e.g.
     /// includes things like "type" and "content.msgtype".
-    flattened_keys: BTreeMap<String, SimpleJsonValue>,
+    flattened_keys: BTreeMap<String, JsonValue>,
 
     /// The "content.body", if any.
     body: String,
@@ -87,7 +88,7 @@ pub struct PushRuleEvaluator {
 
     /// The related events, indexed by relation type. Flattened in the same manner as
     /// `flattened_keys`.
-    related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,
+    related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
 
     /// If msc3664, push rules for related events, is enabled.
     related_event_match_enabled: bool,
@@ -101,6 +102,9 @@ pub struct PushRuleEvaluator {
 
     /// If MSC3758 (exact_event_match push rule condition) is enabled.
     msc3758_exact_event_match: bool,
+
+    /// If MSC3966 (exact_event_property_contains push rule condition) is enabled.
+    msc3966_exact_event_property_contains: bool,
 }
 
 #[pymethods]
@@ -109,21 +113,22 @@ impl PushRuleEvaluator {
     #[allow(clippy::too_many_arguments)]
     #[new]
     pub fn py_new(
-        flattened_keys: BTreeMap<String, SimpleJsonValue>,
+        flattened_keys: BTreeMap<String, JsonValue>,
         has_mentions: bool,
         user_mentions: BTreeSet<String>,
         room_mention: bool,
         room_member_count: u64,
         sender_power_level: Option<i64>,
         notification_power_levels: BTreeMap<String, i64>,
-        related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,
+        related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
         related_event_match_enabled: bool,
         room_version_feature_flags: Vec<String>,
         msc3931_enabled: bool,
         msc3758_exact_event_match: bool,
+        msc3966_exact_event_property_contains: bool,
     ) -> Result<Self, Error> {
         let body = match flattened_keys.get("content.body") {
-            Some(SimpleJsonValue::Str(s)) => s.clone(),
+            Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(),
             _ => String::new(),
         };
 
@@ -141,6 +146,7 @@ impl PushRuleEvaluator {
             room_version_feature_flags,
             msc3931_enabled,
             msc3758_exact_event_match,
+            msc3966_exact_event_property_contains,
         })
     }
 
@@ -263,6 +269,9 @@ impl PushRuleEvaluator {
             KnownCondition::RelatedEventMatch(event_match) => {
                 self.match_related_event_match(event_match, user_id)?
             }
+            KnownCondition::ExactEventPropertyContains(exact_event_match) => {
+                self.match_exact_event_property_contains(exact_event_match)?
+            }
             KnownCondition::IsUserMention => {
                 if let Some(uid) = user_id {
                     self.user_mentions.contains(uid)
@@ -345,7 +354,7 @@ impl PushRuleEvaluator {
             return Ok(false);
         };
 
-        let haystack = if let Some(SimpleJsonValue::Str(haystack)) =
+        let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) =
             self.flattened_keys.get(&*event_match.key)
         {
             haystack
@@ -377,7 +386,9 @@ impl PushRuleEvaluator {
 
         let value = &exact_event_match.value;
 
-        let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) {
+        let haystack = if let Some(JsonValue::Value(haystack)) =
+            self.flattened_keys.get(&*exact_event_match.key)
+        {
             haystack
         } else {
             return Ok(false);
@@ -441,11 +452,12 @@ impl PushRuleEvaluator {
             return Ok(false);
         };
 
-        let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) {
-            haystack
-        } else {
-            return Ok(false);
-        };
+        let haystack =
+            if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) {
+                haystack
+            } else {
+                return Ok(false);
+            };
 
         // For the content.body we match against "words", but for everything
         // else we match against the entire value.
@@ -459,6 +471,29 @@ impl PushRuleEvaluator {
         compiled_pattern.is_match(haystack)
     }
 
+    /// Evaluates a `exact_event_property_contains` condition. (MSC3758)
+    fn match_exact_event_property_contains(
+        &self,
+        exact_event_match: &ExactEventMatchCondition,
+    ) -> Result<bool, Error> {
+        // First check if the feature is enabled.
+        if !self.msc3966_exact_event_property_contains {
+            return Ok(false);
+        }
+
+        let value = &exact_event_match.value;
+
+        let haystack = if let Some(JsonValue::Array(haystack)) =
+            self.flattened_keys.get(&*exact_event_match.key)
+        {
+            haystack
+        } else {
+            return Ok(false);
+        };
+
+        Ok(haystack.contains(&**value))
+    }
+
     /// Match the member count against an 'is' condition
     /// The `is` condition can be things like '>2', '==3' or even just '4'.
     fn match_member_count(&self, is: &str) -> Result<bool, Error> {
@@ -488,7 +523,7 @@ fn push_rule_evaluator() {
     let mut flattened_keys = BTreeMap::new();
     flattened_keys.insert(
         "content.body".to_string(),
-        SimpleJsonValue::Str("foo bar bob hello".to_string()),
+        JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
     );
     let evaluator = PushRuleEvaluator::py_new(
         flattened_keys,
@@ -503,6 +538,7 @@ fn push_rule_evaluator() {
         vec![],
         true,
         true,
+        true,
     )
     .unwrap();
 
@@ -519,7 +555,7 @@ fn test_requires_room_version_supports_condition() {
     let mut flattened_keys = BTreeMap::new();
     flattened_keys.insert(
         "content.body".to_string(),
-        SimpleJsonValue::Str("foo bar bob hello".to_string()),
+        JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
     );
     let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];
     let evaluator = PushRuleEvaluator::py_new(
@@ -535,6 +571,7 @@ fn test_requires_room_version_supports_condition() {
         flags,
         true,
         true,
+        true,
     )
     .unwrap();
 
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 79e519fe11..253b5f367c 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -58,7 +58,7 @@ use anyhow::{Context, Error};
 use log::warn;
 use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
-use pyo3::types::{PyBool, PyLong, PyString};
+use pyo3::types::{PyBool, PyList, PyLong, PyString};
 use pythonize::{depythonize, pythonize};
 use serde::de::Error as _;
 use serde::{Deserialize, Serialize};
@@ -280,6 +280,35 @@ impl<'source> FromPyObject<'source> for SimpleJsonValue {
     }
 }
 
+/// A JSON values (list, string, int, boolean, or null).
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum JsonValue {
+    Array(Vec<SimpleJsonValue>),
+    Value(SimpleJsonValue),
+}
+
+impl<'source> FromPyObject<'source> for JsonValue {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        if let Ok(l) = <PyList as pyo3::PyTryFrom>::try_from(ob) {
+            match l.iter().map(SimpleJsonValue::extract).collect() {
+                Ok(a) => Ok(JsonValue::Array(a)),
+                Err(e) => Err(PyTypeError::new_err(format!(
+                    "Can't convert to JsonValue::Array: {}",
+                    e
+                ))),
+            }
+        } else if let Ok(v) = SimpleJsonValue::extract(ob) {
+            Ok(JsonValue::Value(v))
+        } else {
+            Err(PyTypeError::new_err(format!(
+                "Can't convert from {} to JsonValue",
+                ob.get_type().name()?
+            )))
+        }
+    }
+}
+
 /// A condition used in push rules to match against an event.
 ///
 /// We need this split as `serde` doesn't give us the ability to have a
@@ -303,6 +332,8 @@ pub enum KnownCondition {
     ExactEventMatch(ExactEventMatchCondition),
     #[serde(rename = "im.nheko.msc3664.related_event_match")]
     RelatedEventMatch(RelatedEventMatchCondition),
+    #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")]
+    ExactEventPropertyContains(ExactEventMatchCondition),
     #[serde(rename = "org.matrix.msc3952.is_user_mention")]
     IsUserMention,
     #[serde(rename = "org.matrix.msc3952.is_room_mention")]