diff options
Diffstat (limited to 'rust/src/push/mod.rs')
-rw-r--r-- | rust/src/push/mod.rs | 222 |
1 files changed, 201 insertions, 21 deletions
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 3c4f876cab..575a1c1e68 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; +use pyo3::types::{PyBool, PyList, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -248,6 +250,65 @@ impl<'de> Deserialize<'de> for Action { } } +/// A simple JSON values (string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum SimpleJsonValue { + Str(String), + Int(i64), + Bool(bool), + Null, +} + +impl<'source> FromPyObject<'source> for SimpleJsonValue { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if let Ok(s) = <PyString as pyo3::PyTryFrom>::try_from(ob) { + Ok(SimpleJsonValue::Str(s.to_string())) + // A bool *is* an int, ensure we try bool first. + } else if let Ok(b) = <PyBool as pyo3::PyTryFrom>::try_from(ob) { + Ok(SimpleJsonValue::Bool(b.extract()?)) + } else if let Ok(i) = <PyLong as pyo3::PyTryFrom>::try_from(ob) { + Ok(SimpleJsonValue::Int(i.extract()?)) + } else if ob.is_none() { + Ok(SimpleJsonValue::Null) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to SimpleJsonValue", + ob.get_type().name()? + ))) + } + } +} + +/// A JSON values (list, string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum JsonValue { + Array(Vec<SimpleJsonValue>), + Value(SimpleJsonValue), +} + +impl<'source> FromPyObject<'source> for JsonValue { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if let Ok(l) = <PyList as pyo3::PyTryFrom>::try_from(ob) { + match l.iter().map(SimpleJsonValue::extract).collect() { + Ok(a) => Ok(JsonValue::Array(a)), + Err(e) => Err(PyTypeError::new_err(format!( + "Can't convert to JsonValue::Array: {}", + e + ))), + } + } else if let Ok(v) = SimpleJsonValue::extract(ob) { + Ok(JsonValue::Value(v)) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to JsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -267,12 +328,19 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + // Identical to event_match but gives predefined patterns. Cannot be added by users. + #[serde(skip_deserializing, rename = "event_match")] + EventMatchType(EventMatchTypeCondition), + EventPropertyIs(EventPropertyIsCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), - #[serde(rename = "org.matrix.msc3952.is_user_mention")] - IsUserMention, - #[serde(rename = "org.matrix.msc3952.is_room_mention")] - IsRoomMention, + // Identical to related_event_match but gives predefined patterns. Cannot be added by users. + #[serde(skip_deserializing, rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatchType(RelatedEventMatchTypeCondition), + EventPropertyContains(EventPropertyIsCondition), + // Identical to exact_event_property_contains but gives predefined patterns. Cannot be added by users. + #[serde(skip_deserializing, rename = "event_property_contains")] + ExactEventPropertyContainsType(EventPropertyIsTypeCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -299,14 +367,43 @@ impl<'source> FromPyObject<'source> for Condition { } } -/// The body of a [`Condition::EventMatch`] +/// The body of a [`Condition::EventMatch`] with a pattern. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventMatchCondition { pub key: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none")] - pub pattern: Option<Cow<'static, str>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub pattern_type: Option<Cow<'static, str>>, + pub pattern: Cow<'static, str>, +} + +#[derive(Serialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum EventMatchPatternType { + UserId, + UserLocalpart, +} + +/// The body of a [`Condition::EventMatch`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct EventMatchTypeCondition { + pub key: Cow<'static, str>, + // During serialization, the pattern_type property gets replaced with a + // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user. + pub pattern_type: Cow<'static, EventMatchPatternType>, +} + +/// The body of a [`Condition::EventPropertyIs`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct EventPropertyIsCondition { + pub key: Cow<'static, str>, + pub value: Cow<'static, SimpleJsonValue>, +} + +/// The body of a [`Condition::EventPropertyIs`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct EventPropertyIsTypeCondition { + pub key: Cow<'static, str>, + // During serialization, the pattern_type property gets replaced with a + // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user. + pub value_type: Cow<'static, EventMatchPatternType>, } /// The body of a [`Condition::RelatedEventMatch`] @@ -316,8 +413,18 @@ pub struct RelatedEventMatchCondition { pub key: Option<Cow<'static, str>>, #[serde(skip_serializing_if = "Option::is_none")] pub pattern: Option<Cow<'static, str>>, + pub rel_type: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] - pub pattern_type: Option<Cow<'static, str>>, + pub include_fallbacks: Option<bool>, +} + +/// The body of a [`Condition::RelatedEventMatch`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct RelatedEventMatchTypeCondition { + // This is only used if pattern_type exists (and thus key must exist), so is + // a bit simpler than RelatedEventMatchCondition. + pub key: Cow<'static, str>, + pub pattern_type: Cow<'static, EventMatchPatternType>, pub rel_type: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] pub include_fallbacks: Option<bool>, @@ -501,8 +608,7 @@ impl FilteredPushRules { fn test_serialize_condition() { let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: "content.body".into(), - pattern: Some("coffee".into()), - pattern_type: None, + pattern: "coffee".into(), })); let json = serde_json::to_string(&condition).unwrap(); @@ -516,7 +622,33 @@ fn test_serialize_condition() { fn test_deserialize_condition() { let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#; - let _: Condition = serde_json::from_str(json).unwrap(); + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::EventMatch(_)) + )); +} + +#[test] +fn test_serialize_event_match_condition_with_pattern_type() { + let condition = Condition::Known(KnownCondition::EventMatchType(EventMatchTypeCondition { + key: "content.body".into(), + pattern_type: Cow::Owned(EventMatchPatternType::UserId), + })); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"# + ) +} + +#[test] +fn test_cannot_deserialize_event_match_condition_with_pattern_type() { + let json = r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!(condition, Condition::Unknown(_))); } #[test] @@ -531,6 +663,37 @@ fn test_deserialize_unstable_msc3664_condition() { } #[test] +fn test_serialize_unstable_msc3664_condition_with_pattern_type() { + let condition = Condition::Known(KnownCondition::RelatedEventMatchType( + RelatedEventMatchTypeCondition { + key: "content.body".into(), + pattern_type: Cow::Owned(EventMatchPatternType::UserId), + rel_type: "m.in_reply_to".into(), + include_fallbacks: Some(true), + }, + )); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to","include_fallbacks":true}"# + ) +} + +#[test] +fn test_cannot_deserialize_unstable_msc3664_condition_with_pattern_type() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + // Since pattern is optional on RelatedEventMatch it deserializes it to that + // instead of RelatedEventMatchType. + assert!(matches!( + condition, + Condition::Known(KnownCondition::RelatedEventMatch(_)) + )); +} + +#[test] fn test_deserialize_unstable_msc3931_condition() { let json = r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; @@ -543,24 +706,41 @@ fn test_deserialize_unstable_msc3931_condition() { } #[test] -fn test_deserialize_unstable_msc3952_user_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; +fn test_deserialize_event_property_is_condition() { + // A string condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":"foo"}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::IsUserMention) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); -} -#[test] -fn test_deserialize_unstable_msc3952_room_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + // A boolean condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":true}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::EventPropertyIs(_)) + )); + + // An integer condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":1}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::EventPropertyIs(_)) + )); + + // A null condition should work + let json = r#"{"kind":"event_property_is","key":"content.value","value":null}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::IsRoomMention) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); } |