summary refs log tree commit diff
path: root/rust
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--rust/Cargo.toml8
-rw-r--r--rust/benches/evaluator.rs5
-rw-r--r--rust/src/acl/mod.rs4
-rw-r--r--rust/src/events/filter.rs107
-rw-r--r--rust/src/events/internal_metadata.rs82
-rw-r--r--rust/src/events/mod.rs8
-rw-r--r--rust/src/http.rs2
-rw-r--r--rust/src/identifier.rs252
-rw-r--r--rust/src/lib.rs17
-rw-r--r--rust/src/matrix_const.rs28
-rw-r--r--rust/src/push/base_rules.rs2
-rw-r--r--rust/src/push/evaluator.rs15
-rw-r--r--rust/src/push/mod.rs55
-rw-r--r--rust/src/push/utils.rs1
-rw-r--r--rust/src/rendezvous/mod.rs26
15 files changed, 551 insertions, 61 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml

index 026487275c..840988e74e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml
@@ -30,14 +30,14 @@ http = "1.1.0" lazy_static = "1.4.0" log = "0.4.17" mime = "0.3.17" -pyo3 = { version = "0.21.0", features = [ +pyo3 = { version = "0.24.2", features = [ "macros", "anyhow", "abi3", - "abi3-py38", + "abi3-py39", ] } -pyo3-log = "0.10.0" -pythonize = "0.21.0" +pyo3-log = "0.12.0" +pythonize = "0.24.0" regex = "1.6.0" sha2 = "0.10.8" serde = { version = "1.0.144", features = ["derive"] } diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs
index 4fea035b96..28537e187e 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs
@@ -60,6 +60,7 @@ fn bench_match_exact(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -105,6 +106,7 @@ fn bench_match_word(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -150,6 +152,7 @@ fn bench_match_word_miss(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -195,6 +198,7 @@ fn bench_eval_message(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -205,6 +209,7 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); diff --git a/rust/src/acl/mod.rs b/rust/src/acl/mod.rs
index 982720ba90..57b45475fd 100644 --- a/rust/src/acl/mod.rs +++ b/rust/src/acl/mod.rs
@@ -32,14 +32,14 @@ use crate::push::utils::{glob_to_regex, GlobMatchType}; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - let child_module = PyModule::new_bound(py, "acl")?; + let child_module = PyModule::new(py, "acl")?; child_module.add_class::<ServerAclEvaluator>()?; m.add_submodule(&child_module)?; // We need to manually add the module to sys.modules to make `from // synapse.synapse_rust import acl` work. - py.import_bound("sys")? + py.import("sys")? .getattr("modules")? .set_item("synapse.synapse_rust.acl", child_module)?; diff --git a/rust/src/events/filter.rs b/rust/src/events/filter.rs new file mode 100644
index 0000000000..7e39972c62 --- /dev/null +++ b/rust/src/events/filter.rs
@@ -0,0 +1,107 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + */ + +use std::collections::HashMap; + +use pyo3::{exceptions::PyValueError, pyfunction, PyResult}; + +use crate::{ + identifier::UserID, + matrix_const::{ + HISTORY_VISIBILITY_INVITED, HISTORY_VISIBILITY_JOINED, MEMBERSHIP_INVITE, MEMBERSHIP_JOIN, + }, +}; + +#[pyfunction(name = "event_visible_to_server")] +pub fn event_visible_to_server_py( + sender: String, + target_server_name: String, + history_visibility: String, + erased_senders: HashMap<String, bool>, + partial_state_invisible: bool, + memberships: Vec<(String, String)>, // (state_key, membership) +) -> PyResult<bool> { + event_visible_to_server( + sender, + target_server_name, + history_visibility, + erased_senders, + partial_state_invisible, + memberships, + ) + .map_err(|e| PyValueError::new_err(format!("{e}"))) +} + +/// Return whether the target server is allowed to see the event. +/// +/// For a fully stated room, the target server is allowed to see an event E if: +/// - the state at E has world readable or shared history vis, OR +/// - the state at E says that the target server is in the room. +/// +/// For a partially stated room, the target server is allowed to see E if: +/// - E was created by this homeserver, AND: +/// - the partial state at E has world readable or shared history vis, OR +/// - the partial state at E says that the target server is in the room. +pub fn event_visible_to_server( + sender: String, + target_server_name: String, + history_visibility: String, + erased_senders: HashMap<String, bool>, + partial_state_invisible: bool, + memberships: Vec<(String, String)>, // (state_key, membership) +) -> anyhow::Result<bool> { + if let Some(&erased) = erased_senders.get(&sender) { + if erased { + return Ok(false); + } + } + + if partial_state_invisible { + return Ok(false); + } + + if history_visibility != HISTORY_VISIBILITY_INVITED + && history_visibility != HISTORY_VISIBILITY_JOINED + { + return Ok(true); + } + + let mut visible = false; + for (state_key, membership) in memberships { + let state_key = UserID::try_from(state_key.as_ref()) + .map_err(|e| anyhow::anyhow!(format!("invalid user_id ({state_key}): {e}")))?; + if state_key.server_name() != target_server_name { + return Err(anyhow::anyhow!( + "state_key.server_name ({}) does not match target_server_name ({target_server_name})", + state_key.server_name() + )); + } + + match membership.as_str() { + MEMBERSHIP_INVITE => { + if history_visibility == HISTORY_VISIBILITY_INVITED { + visible = true; + break; + } + } + MEMBERSHIP_JOIN => { + visible = true; + break; + } + _ => continue, + } + } + + Ok(visible) +} diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs
index ad87825f16..eeb6074c10 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs
@@ -41,9 +41,11 @@ use pyo3::{ pybacked::PyBackedStr, pyclass, pymethods, types::{PyAnyMethods, PyDict, PyDictMethods, PyString}, - Bound, IntoPy, PyAny, PyObject, PyResult, Python, + Bound, IntoPyObject, PyAny, PyObject, PyResult, Python, }; +use crate::UnwrapInfallible; + /// Definitions of the various fields of the internal metadata. #[derive(Clone)] enum EventInternalMetadataData { @@ -60,31 +62,59 @@ enum EventInternalMetadataData { impl EventInternalMetadataData { /// Convert the field to its name and python object. - fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, PyObject) { + fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, Bound<'a, PyAny>) { match self { - EventInternalMetadataData::OutOfBandMembership(o) => { - (pyo3::intern!(py, "out_of_band_membership"), o.into_py(py)) - } - EventInternalMetadataData::SendOnBehalfOf(o) => { - (pyo3::intern!(py, "send_on_behalf_of"), o.into_py(py)) - } - EventInternalMetadataData::RecheckRedaction(o) => { - (pyo3::intern!(py, "recheck_redaction"), o.into_py(py)) - } - EventInternalMetadataData::SoftFailed(o) => { - (pyo3::intern!(py, "soft_failed"), o.into_py(py)) - } - EventInternalMetadataData::ProactivelySend(o) => { - (pyo3::intern!(py, "proactively_send"), o.into_py(py)) - } - EventInternalMetadataData::Redacted(o) => { - (pyo3::intern!(py, "redacted"), o.into_py(py)) - } - EventInternalMetadataData::TxnId(o) => (pyo3::intern!(py, "txn_id"), o.into_py(py)), - EventInternalMetadataData::TokenId(o) => (pyo3::intern!(py, "token_id"), o.into_py(py)), - EventInternalMetadataData::DeviceId(o) => { - (pyo3::intern!(py, "device_id"), o.into_py(py)) - } + EventInternalMetadataData::OutOfBandMembership(o) => ( + pyo3::intern!(py, "out_of_band_membership"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), + EventInternalMetadataData::SendOnBehalfOf(o) => ( + pyo3::intern!(py, "send_on_behalf_of"), + o.into_pyobject(py).unwrap_infallible().into_any(), + ), + EventInternalMetadataData::RecheckRedaction(o) => ( + pyo3::intern!(py, "recheck_redaction"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), + EventInternalMetadataData::SoftFailed(o) => ( + pyo3::intern!(py, "soft_failed"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), + EventInternalMetadataData::ProactivelySend(o) => ( + pyo3::intern!(py, "proactively_send"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), + EventInternalMetadataData::Redacted(o) => ( + pyo3::intern!(py, "redacted"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), + EventInternalMetadataData::TxnId(o) => ( + pyo3::intern!(py, "txn_id"), + o.into_pyobject(py).unwrap_infallible().into_any(), + ), + EventInternalMetadataData::TokenId(o) => ( + pyo3::intern!(py, "token_id"), + o.into_pyobject(py).unwrap_infallible().into_any(), + ), + EventInternalMetadataData::DeviceId(o) => ( + pyo3::intern!(py, "device_id"), + o.into_pyobject(py).unwrap_infallible().into_any(), + ), } } @@ -247,7 +277,7 @@ impl EventInternalMetadata { /// /// Note that `outlier` and `stream_ordering` are stored in separate columns so are not returned here. fn get_dict(&self, py: Python<'_>) -> PyResult<PyObject> { - let dict = PyDict::new_bound(py); + let dict = PyDict::new(py); for entry in &self.data { let (key, value) = entry.to_python_pair(py); diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs
index a4ade1a178..209efb917b 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs
@@ -22,21 +22,23 @@ use pyo3::{ types::{PyAnyMethods, PyModule, PyModuleMethods}, - Bound, PyResult, Python, + wrap_pyfunction, Bound, PyResult, Python, }; +pub mod filter; mod internal_metadata; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - let child_module = PyModule::new_bound(py, "events")?; + let child_module = PyModule::new(py, "events")?; child_module.add_class::<internal_metadata::EventInternalMetadata>()?; + child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; m.add_submodule(&child_module)?; // We need to manually add the module to sys.modules to make `from // synapse.synapse_rust import events` work. - py.import_bound("sys")? + py.import("sys")? .getattr("modules")? .set_item("synapse.synapse_rust.events", child_module)?; diff --git a/rust/src/http.rs b/rust/src/http.rs
index af052ab721..63ed05be54 100644 --- a/rust/src/http.rs +++ b/rust/src/http.rs
@@ -70,7 +70,7 @@ pub fn http_request_from_twisted(request: &Bound<'_, PyAny>) -> PyResult<Request let headers_iter = request .getattr("requestHeaders")? .call_method0("getAllRawHeaders")? - .iter()?; + .try_iter()?; for header in headers_iter { let header = header?; diff --git a/rust/src/identifier.rs b/rust/src/identifier.rs new file mode 100644
index 0000000000..b70f6a30c7 --- /dev/null +++ b/rust/src/identifier.rs
@@ -0,0 +1,252 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + */ + +//! # Matrix Identifiers +//! +//! This module contains definitions and utilities for working with matrix identifiers. + +use std::{fmt, ops::Deref}; + +/// Errors that can occur when parsing a matrix identifier. +#[derive(Clone, Debug, PartialEq)] +pub enum IdentifierError { + IncorrectSigil, + MissingColon, +} + +impl fmt::Display for IdentifierError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +/// A Matrix user_id. +#[derive(Clone, Debug, PartialEq)] +pub struct UserID(String); + +impl UserID { + /// Returns the `localpart` of the user_id. + pub fn localpart(&self) -> &str { + &self[1..self.colon_pos()] + } + + /// Returns the `server_name` / `domain` of the user_id. + pub fn server_name(&self) -> &str { + &self[self.colon_pos() + 1..] + } + + /// Returns the position of the ':' inside of the user_id. + /// Used when splitting the user_id into it's respective parts. + fn colon_pos(&self) -> usize { + self.find(':').unwrap() + } +} + +impl TryFrom<&str> for UserID { + type Error = IdentifierError; + + /// Will try creating a `UserID` from the provided `&str`. + /// Can fail if the user_id is incorrectly formatted. + fn try_from(s: &str) -> Result<Self, Self::Error> { + if !s.starts_with('@') { + return Err(IdentifierError::IncorrectSigil); + } + + if s.find(':').is_none() { + return Err(IdentifierError::MissingColon); + } + + Ok(UserID(s.to_string())) + } +} + +impl TryFrom<String> for UserID { + type Error = IdentifierError; + + /// Will try creating a `UserID` from the provided `&str`. + /// Can fail if the user_id is incorrectly formatted. + fn try_from(s: String) -> Result<Self, Self::Error> { + if !s.starts_with('@') { + return Err(IdentifierError::IncorrectSigil); + } + + if s.find(':').is_none() { + return Err(IdentifierError::MissingColon); + } + + Ok(UserID(s)) + } +} + +impl<'de> serde::Deserialize<'de> for UserID { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s: String = serde::Deserialize::deserialize(deserializer)?; + UserID::try_from(s).map_err(serde::de::Error::custom) + } +} + +impl Deref for UserID { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for UserID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A Matrix room_id. +#[derive(Clone, Debug, PartialEq)] +pub struct RoomID(String); + +impl RoomID { + /// Returns the `localpart` of the room_id. + pub fn localpart(&self) -> &str { + &self[1..self.colon_pos()] + } + + /// Returns the `server_name` / `domain` of the room_id. + pub fn server_name(&self) -> &str { + &self[self.colon_pos() + 1..] + } + + /// Returns the position of the ':' inside of the room_id. + /// Used when splitting the room_id into it's respective parts. + fn colon_pos(&self) -> usize { + self.find(':').unwrap() + } +} + +impl TryFrom<&str> for RoomID { + type Error = IdentifierError; + + /// Will try creating a `RoomID` from the provided `&str`. + /// Can fail if the room_id is incorrectly formatted. + fn try_from(s: &str) -> Result<Self, Self::Error> { + if !s.starts_with('!') { + return Err(IdentifierError::IncorrectSigil); + } + + if s.find(':').is_none() { + return Err(IdentifierError::MissingColon); + } + + Ok(RoomID(s.to_string())) + } +} + +impl TryFrom<String> for RoomID { + type Error = IdentifierError; + + /// Will try creating a `RoomID` from the provided `String`. + /// Can fail if the room_id is incorrectly formatted. + fn try_from(s: String) -> Result<Self, Self::Error> { + if !s.starts_with('!') { + return Err(IdentifierError::IncorrectSigil); + } + + if s.find(':').is_none() { + return Err(IdentifierError::MissingColon); + } + + Ok(RoomID(s)) + } +} + +impl<'de> serde::Deserialize<'de> for RoomID { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s: String = serde::Deserialize::deserialize(deserializer)?; + RoomID::try_from(s).map_err(serde::de::Error::custom) + } +} + +impl Deref for RoomID { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for RoomID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A Matrix event_id. +#[derive(Clone, Debug, PartialEq)] +pub struct EventID(String); + +impl TryFrom<&str> for EventID { + type Error = IdentifierError; + + /// Will try creating a `EventID` from the provided `&str`. + /// Can fail if the event_id is incorrectly formatted. + fn try_from(s: &str) -> Result<Self, Self::Error> { + if !s.starts_with('$') { + return Err(IdentifierError::IncorrectSigil); + } + + Ok(EventID(s.to_string())) + } +} + +impl TryFrom<String> for EventID { + type Error = IdentifierError; + + /// Will try creating a `EventID` from the provided `String`. + /// Can fail if the event_id is incorrectly formatted. + fn try_from(s: String) -> Result<Self, Self::Error> { + if !s.starts_with('$') { + return Err(IdentifierError::IncorrectSigil); + } + + Ok(EventID(s)) + } +} + +impl<'de> serde::Deserialize<'de> for EventID { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s: String = serde::Deserialize::deserialize(deserializer)?; + EventID::try_from(s).map_err(serde::de::Error::custom) + } +} + +impl Deref for EventID { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for EventID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 06477880b9..d751889874 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs
@@ -1,3 +1,5 @@ +use std::convert::Infallible; + use lazy_static::lazy_static; use pyo3::prelude::*; use pyo3_log::ResetHandle; @@ -6,6 +8,8 @@ pub mod acl; pub mod errors; pub mod events; pub mod http; +pub mod identifier; +pub mod matrix_const; pub mod push; pub mod rendezvous; @@ -50,3 +54,16 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +pub trait UnwrapInfallible<T> { + fn unwrap_infallible(self) -> T; +} + +impl<T> UnwrapInfallible<T> for Result<T, Infallible> { + fn unwrap_infallible(self) -> T { + match self { + Ok(val) => val, + Err(never) => match never {}, + } + } +} diff --git a/rust/src/matrix_const.rs b/rust/src/matrix_const.rs new file mode 100644
index 0000000000..f75f3bd7c3 --- /dev/null +++ b/rust/src/matrix_const.rs
@@ -0,0 +1,28 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + */ + +//! # Matrix Constants +//! +//! This module contains definitions for constant values described by the matrix specification. + +pub const HISTORY_VISIBILITY_WORLD_READABLE: &str = "world_readable"; +pub const HISTORY_VISIBILITY_SHARED: &str = "shared"; +pub const HISTORY_VISIBILITY_INVITED: &str = "invited"; +pub const HISTORY_VISIBILITY_JOINED: &str = "joined"; + +pub const MEMBERSHIP_BAN: &str = "ban"; +pub const MEMBERSHIP_LEAVE: &str = "leave"; +pub const MEMBERSHIP_KNOCK: &str = "knock"; +pub const MEMBERSHIP_INVITE: &str = "invite"; +pub const MEMBERSHIP_JOIN: &str = "join"; diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs
index 74f02d6001..e0832ada1c 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs
@@ -81,7 +81,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ ))]), actions: Cow::Borrowed(&[Action::Notify]), default: true, - default_enabled: false, + default_enabled: true, }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
index 2f4b6d47bb..db406acb88 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs
@@ -105,6 +105,9 @@ pub struct PushRuleEvaluator { /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, + + // If MSC4210 (remove legacy mentions) is enabled. + msc4210_enabled: bool, } #[pymethods] @@ -122,6 +125,7 @@ impl PushRuleEvaluator { related_event_match_enabled, room_version_feature_flags, msc3931_enabled, + msc4210_enabled, ))] pub fn py_new( flattened_keys: BTreeMap<String, JsonValue>, @@ -133,6 +137,7 @@ impl PushRuleEvaluator { related_event_match_enabled: bool, room_version_feature_flags: Vec<String>, msc3931_enabled: bool, + msc4210_enabled: bool, ) -> Result<Self, Error> { let body = match flattened_keys.get("content.body") { Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(), @@ -150,6 +155,7 @@ impl PushRuleEvaluator { related_event_match_enabled, room_version_feature_flags, msc3931_enabled, + msc4210_enabled, }) } @@ -161,6 +167,7 @@ impl PushRuleEvaluator { /// /// Returns the set of actions, if any, that match (filtering out any /// `dont_notify` and `coalesce` actions). + #[pyo3(signature = (push_rules, user_id=None, display_name=None))] pub fn run( &self, push_rules: &FilteredPushRules, @@ -176,7 +183,8 @@ impl PushRuleEvaluator { // For backwards-compatibility the legacy mention rules are disabled // if the event contains the 'm.mentions' property. - if self.has_mentions + // Additionally, MSC4210 always disables the legacy rules. + if (self.has_mentions || self.msc4210_enabled) && (rule_id == "global/override/.m.rule.contains_display_name" || rule_id == "global/content/.m.rule.contains_user_name" || rule_id == "global/override/.m.rule.roomnotif") @@ -229,6 +237,7 @@ impl PushRuleEvaluator { } /// Check if the given condition matches. + #[pyo3(signature = (condition, user_id=None, display_name=None))] fn matches( &self, condition: Condition, @@ -526,6 +535,7 @@ fn push_rule_evaluator() { true, vec![], true, + false, ) .unwrap(); @@ -555,6 +565,7 @@ fn test_requires_room_version_supports_condition() { false, flags, true, + false, ) .unwrap(); @@ -582,7 +593,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 2a452b69a3..bd0e853ac3 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs
@@ -65,8 +65,8 @@ use anyhow::{Context, Error}; use log::warn; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyList, PyLong, PyString}; -use pythonize::{depythonize_bound, pythonize}; +use pyo3::types::{PyBool, PyInt, PyList, PyString}; +use pythonize::{depythonize, pythonize, PythonizeError}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -79,7 +79,7 @@ pub mod utils; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - let child_module = PyModule::new_bound(py, "push")?; + let child_module = PyModule::new(py, "push")?; child_module.add_class::<PushRule>()?; child_module.add_class::<PushRules>()?; child_module.add_class::<FilteredPushRules>()?; @@ -90,7 +90,7 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> // We need to manually add the module to sys.modules to make `from // synapse.synapse_rust import push` work. - py.import_bound("sys")? + py.import("sys")? .getattr("modules")? .set_item("synapse.synapse_rust.push", child_module)?; @@ -182,12 +182,16 @@ pub enum Action { Unknown(Value), } -impl IntoPy<PyObject> for Action { - fn into_py(self, py: Python<'_>) -> PyObject { +impl<'py> IntoPyObject<'py> for Action { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PythonizeError; + + fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { // When we pass the `Action` struct to Python we want it to be converted // to a dict. We use `pythonize`, which converts the struct using the // `serde` serialization. - pythonize(py, &self).expect("valid action") + pythonize(py, &self) } } @@ -270,13 +274,13 @@ pub enum SimpleJsonValue { } impl<'source> FromPyObject<'source> for SimpleJsonValue { - fn extract(ob: &'source PyAny) -> PyResult<Self> { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> { if let Ok(s) = ob.downcast::<PyString>() { Ok(SimpleJsonValue::Str(Cow::Owned(s.to_string()))) // A bool *is* an int, ensure we try bool first. } else if let Ok(b) = ob.downcast::<PyBool>() { Ok(SimpleJsonValue::Bool(b.extract()?)) - } else if let Ok(i) = ob.downcast::<PyLong>() { + } else if let Ok(i) = ob.downcast::<PyInt>() { Ok(SimpleJsonValue::Int(i.extract()?)) } else if ob.is_none() { Ok(SimpleJsonValue::Null) @@ -298,15 +302,19 @@ pub enum JsonValue { } impl<'source> FromPyObject<'source> for JsonValue { - fn extract(ob: &'source PyAny) -> PyResult<Self> { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> { if let Ok(l) = ob.downcast::<PyList>() { - match l.iter().map(SimpleJsonValue::extract).collect() { + match l + .iter() + .map(|it| SimpleJsonValue::extract_bound(&it)) + .collect() + { Ok(a) => Ok(JsonValue::Array(a)), Err(e) => Err(PyTypeError::new_err(format!( "Can't convert to JsonValue::Array: {e}" ))), } - } else if let Ok(v) = SimpleJsonValue::extract(ob) { + } else if let Ok(v) = SimpleJsonValue::extract_bound(ob) { Ok(JsonValue::Value(v)) } else { Err(PyTypeError::new_err(format!( @@ -363,15 +371,19 @@ pub enum KnownCondition { }, } -impl IntoPy<PyObject> for Condition { - fn into_py(self, py: Python<'_>) -> PyObject { - pythonize(py, &self).expect("valid condition") +impl<'source> IntoPyObject<'source> for Condition { + type Target = PyAny; + type Output = Bound<'source, Self::Target>; + type Error = PythonizeError; + + fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> { + pythonize(py, &self) } } impl<'source> FromPyObject<'source> for Condition { fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> { - Ok(depythonize_bound(ob.clone())?) + Ok(depythonize(ob)?) } } @@ -534,6 +546,7 @@ pub struct FilteredPushRules { msc3381_polls_enabled: bool, msc3664_enabled: bool, msc4028_push_encrypted_events: bool, + msc4210_enabled: bool, } #[pymethods] @@ -546,6 +559,7 @@ impl FilteredPushRules { msc3381_polls_enabled: bool, msc3664_enabled: bool, msc4028_push_encrypted_events: bool, + msc4210_enabled: bool, ) -> Self { Self { push_rules, @@ -554,6 +568,7 @@ impl FilteredPushRules { msc3381_polls_enabled, msc3664_enabled, msc4028_push_encrypted_events, + msc4210_enabled, } } @@ -596,6 +611,14 @@ impl FilteredPushRules { return false; } + if self.msc4210_enabled + && (rule.rule_id == "global/override/.m.rule.contains_display_name" + || rule.rule_id == "global/content/.m.rule.contains_user_name" + || rule.rule_id == "global/override/.m.rule.roomnotif") + { + return false; + } + true }) .map(|r| { diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs
index 28ebed62c8..59536c9954 100644 --- a/rust/src/push/utils.rs +++ b/rust/src/push/utils.rs
@@ -23,7 +23,6 @@ use anyhow::bail; use anyhow::Context; use anyhow::Error; use lazy_static::lazy_static; -use regex; use regex::Regex; use regex::RegexBuilder; diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs
index f69f45490f..3148e0f67a 100644 --- a/rust/src/rendezvous/mod.rs +++ b/rust/src/rendezvous/mod.rs
@@ -29,7 +29,7 @@ use pyo3::{ exceptions::PyValueError, pyclass, pymethods, types::{PyAnyMethods, PyModule, PyModuleMethods}, - Bound, Py, PyAny, PyObject, PyResult, Python, ToPyObject, + Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python, }; use ulid::Ulid; @@ -37,6 +37,7 @@ use self::session::Session; use crate::{ errors::{NotFoundError, SynapseError}, http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt}, + UnwrapInfallible, }; mod session; @@ -46,7 +47,7 @@ fn prepare_headers(headers: &mut HeaderMap, session: &Session) { headers.typed_insert(AccessControlAllowOrigin::ANY); headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG])); headers.typed_insert(Pragma::no_cache()); - headers.typed_insert(CacheControl::new().with_no_store()); + headers.typed_insert(CacheControl::new().with_no_store().with_no_transform()); headers.typed_insert(session.etag()); headers.typed_insert(session.expires()); headers.typed_insert(session.last_modified()); @@ -125,7 +126,11 @@ impl RendezvousHandler { let base = Uri::try_from(format!("{base}_synapse/client/rendezvous")) .map_err(|_| PyValueError::new_err("Invalid base URI"))?; - let clock = homeserver.call_method0("get_clock")?.to_object(py); + let clock = homeserver + .call_method0("get_clock")? + .into_pyobject(py) + .unwrap_infallible() + .unbind(); // Construct a Python object so that we can get a reference to the // evict method and schedule it to run. @@ -187,10 +192,12 @@ impl RendezvousHandler { "url": uri, }) .to_string(); + let length = response.len() as _; let mut response = Response::new(response.as_bytes()); *response.status_mut() = StatusCode::CREATED; response.headers_mut().typed_insert(ContentType::json()); + response.headers_mut().typed_insert(ContentLength(length)); prepare_headers(response.headers_mut(), &session); http_response_to_twisted(twisted_request, response)?; @@ -288,6 +295,14 @@ impl RendezvousHandler { let mut response = Response::new(Bytes::new()); *response.status_mut() = StatusCode::ACCEPTED; prepare_headers(response.headers_mut(), session); + + // Even though this isn't mandated by the MSC, we set a Content-Type on the response. It + // doesn't do any harm as the body is empty, but this helps escape a bug in some reverse + // proxy/cache setup which strips the ETag header if there is no Content-Type set. + // Specifically, we noticed this behaviour when placing Synapse behind Cloudflare. + response.headers_mut().typed_insert(ContentType::text()); + response.headers_mut().typed_insert(ContentLength(0)); + http_response_to_twisted(twisted_request, response)?; Ok(()) @@ -304,6 +319,7 @@ impl RendezvousHandler { response .headers_mut() .typed_insert(AccessControlAllowOrigin::ANY); + response.headers_mut().typed_insert(ContentLength(0)); http_response_to_twisted(twisted_request, response)?; Ok(()) @@ -311,7 +327,7 @@ impl RendezvousHandler { } pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { - let child_module = PyModule::new_bound(py, "rendezvous")?; + let child_module = PyModule::new(py, "rendezvous")?; child_module.add_class::<RendezvousHandler>()?; @@ -319,7 +335,7 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> // We need to manually add the module to sys.modules to make `from // synapse.synapse_rust import rendezvous` work. - py.import_bound("sys")? + py.import("sys")? .getattr("modules")? .set_item("synapse.synapse_rust.rendezvous", child_module)?;