summary refs log tree commit diff
path: root/rust/src/push/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/src/push/mod.rs')
-rw-r--r--rust/src/push/mod.rs83
1 files changed, 83 insertions, 0 deletions
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 3c4f876cab..79e519fe11 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet};
 
 use anyhow::{Context, Error};
 use log::warn;
+use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
+use pyo3::types::{PyBool, PyLong, PyString};
 use pythonize::{depythonize, pythonize};
 use serde::de::Error as _;
 use serde::{Deserialize, Serialize};
@@ -248,6 +250,36 @@ impl<'de> Deserialize<'de> for Action {
     }
 }
 
+/// A simple JSON values (string, int, boolean, or null).
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum SimpleJsonValue {
+    Str(String),
+    Int(i64),
+    Bool(bool),
+    Null,
+}
+
+impl<'source> FromPyObject<'source> for SimpleJsonValue {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        if let Ok(s) = <PyString as pyo3::PyTryFrom>::try_from(ob) {
+            Ok(SimpleJsonValue::Str(s.to_string()))
+        // A bool *is* an int, ensure we try bool first.
+        } else if let Ok(b) = <PyBool as pyo3::PyTryFrom>::try_from(ob) {
+            Ok(SimpleJsonValue::Bool(b.extract()?))
+        } else if let Ok(i) = <PyLong as pyo3::PyTryFrom>::try_from(ob) {
+            Ok(SimpleJsonValue::Int(i.extract()?))
+        } else if ob.is_none() {
+            Ok(SimpleJsonValue::Null)
+        } else {
+            Err(PyTypeError::new_err(format!(
+                "Can't convert from {} to SimpleJsonValue",
+                ob.get_type().name()?
+            )))
+        }
+    }
+}
+
 /// A condition used in push rules to match against an event.
 ///
 /// We need this split as `serde` doesn't give us the ability to have a
@@ -267,6 +299,8 @@ pub enum Condition {
 #[serde(tag = "kind")]
 pub enum KnownCondition {
     EventMatch(EventMatchCondition),
+    #[serde(rename = "com.beeper.msc3758.exact_event_match")]
+    ExactEventMatch(ExactEventMatchCondition),
     #[serde(rename = "im.nheko.msc3664.related_event_match")]
     RelatedEventMatch(RelatedEventMatchCondition),
     #[serde(rename = "org.matrix.msc3952.is_user_mention")]
@@ -309,6 +343,13 @@ pub struct EventMatchCondition {
     pub pattern_type: Option<Cow<'static, str>>,
 }
 
+/// The body of a [`Condition::ExactEventMatch`]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ExactEventMatchCondition {
+    pub key: Cow<'static, str>,
+    pub value: Cow<'static, SimpleJsonValue>,
+}
+
 /// The body of a [`Condition::RelatedEventMatch`]
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct RelatedEventMatchCondition {
@@ -543,6 +584,48 @@ fn test_deserialize_unstable_msc3931_condition() {
 }
 
 #[test]
+fn test_deserialize_unstable_msc3758_condition() {
+    // A string condition should work.
+    let json =
+        r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#;
+
+    let condition: Condition = serde_json::from_str(json).unwrap();
+    assert!(matches!(
+        condition,
+        Condition::Known(KnownCondition::ExactEventMatch(_))
+    ));
+
+    // A boolean condition should work.
+    let json =
+        r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#;
+
+    let condition: Condition = serde_json::from_str(json).unwrap();
+    assert!(matches!(
+        condition,
+        Condition::Known(KnownCondition::ExactEventMatch(_))
+    ));
+
+    // An integer condition should work.
+    let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#;
+
+    let condition: Condition = serde_json::from_str(json).unwrap();
+    assert!(matches!(
+        condition,
+        Condition::Known(KnownCondition::ExactEventMatch(_))
+    ));
+
+    // A null condition should work
+    let json =
+        r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#;
+
+    let condition: Condition = serde_json::from_str(json).unwrap();
+    assert!(matches!(
+        condition,
+        Condition::Known(KnownCondition::ExactEventMatch(_))
+    ));
+}
+
+#[test]
 fn test_deserialize_unstable_msc3952_user_condition() {
     let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#;