diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 026487275c..840988e74e 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -30,14 +30,14 @@ http = "1.1.0"
lazy_static = "1.4.0"
log = "0.4.17"
mime = "0.3.17"
-pyo3 = { version = "0.21.0", features = [
+pyo3 = { version = "0.24.2", features = [
"macros",
"anyhow",
"abi3",
- "abi3-py38",
+ "abi3-py39",
] }
-pyo3-log = "0.10.0"
-pythonize = "0.21.0"
+pyo3-log = "0.12.0"
+pythonize = "0.24.0"
regex = "1.6.0"
sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] }
diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs
index 4fea035b96..28537e187e 100644
--- a/rust/benches/evaluator.rs
+++ b/rust/benches/evaluator.rs
@@ -60,6 +60,7 @@ fn bench_match_exact(b: &mut Bencher) {
true,
vec![],
false,
+ false,
)
.unwrap();
@@ -105,6 +106,7 @@ fn bench_match_word(b: &mut Bencher) {
true,
vec![],
false,
+ false,
)
.unwrap();
@@ -150,6 +152,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
true,
vec![],
false,
+ false,
)
.unwrap();
@@ -195,6 +198,7 @@ fn bench_eval_message(b: &mut Bencher) {
true,
vec![],
false,
+ false,
)
.unwrap();
@@ -205,6 +209,7 @@ fn bench_eval_message(b: &mut Bencher) {
false,
false,
false,
+ false,
);
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
diff --git a/rust/src/acl/mod.rs b/rust/src/acl/mod.rs
index 982720ba90..57b45475fd 100644
--- a/rust/src/acl/mod.rs
+++ b/rust/src/acl/mod.rs
@@ -32,14 +32,14 @@ use crate::push::utils::{glob_to_regex, GlobMatchType};
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
- let child_module = PyModule::new_bound(py, "acl")?;
+ let child_module = PyModule::new(py, "acl")?;
child_module.add_class::<ServerAclEvaluator>()?;
m.add_submodule(&child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import acl` work.
- py.import_bound("sys")?
+ py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.acl", child_module)?;
diff --git a/rust/src/events/filter.rs b/rust/src/events/filter.rs
new file mode 100644
index 0000000000..7e39972c62
--- /dev/null
+++ b/rust/src/events/filter.rs
@@ -0,0 +1,107 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use std::collections::HashMap;
+
+use pyo3::{exceptions::PyValueError, pyfunction, PyResult};
+
+use crate::{
+ identifier::UserID,
+ matrix_const::{
+ HISTORY_VISIBILITY_INVITED, HISTORY_VISIBILITY_JOINED, MEMBERSHIP_INVITE, MEMBERSHIP_JOIN,
+ },
+};
+
+#[pyfunction(name = "event_visible_to_server")]
+pub fn event_visible_to_server_py(
+ sender: String,
+ target_server_name: String,
+ history_visibility: String,
+ erased_senders: HashMap<String, bool>,
+ partial_state_invisible: bool,
+ memberships: Vec<(String, String)>, // (state_key, membership)
+) -> PyResult<bool> {
+ event_visible_to_server(
+ sender,
+ target_server_name,
+ history_visibility,
+ erased_senders,
+ partial_state_invisible,
+ memberships,
+ )
+ .map_err(|e| PyValueError::new_err(format!("{e}")))
+}
+
+/// Return whether the target server is allowed to see the event.
+///
+/// For a fully stated room, the target server is allowed to see an event E if:
+/// - the state at E has world readable or shared history vis, OR
+/// - the state at E says that the target server is in the room.
+///
+/// For a partially stated room, the target server is allowed to see E if:
+/// - E was created by this homeserver, AND:
+/// - the partial state at E has world readable or shared history vis, OR
+/// - the partial state at E says that the target server is in the room.
+pub fn event_visible_to_server(
+ sender: String,
+ target_server_name: String,
+ history_visibility: String,
+ erased_senders: HashMap<String, bool>,
+ partial_state_invisible: bool,
+ memberships: Vec<(String, String)>, // (state_key, membership)
+) -> anyhow::Result<bool> {
+ if let Some(&erased) = erased_senders.get(&sender) {
+ if erased {
+ return Ok(false);
+ }
+ }
+
+ if partial_state_invisible {
+ return Ok(false);
+ }
+
+ if history_visibility != HISTORY_VISIBILITY_INVITED
+ && history_visibility != HISTORY_VISIBILITY_JOINED
+ {
+ return Ok(true);
+ }
+
+ let mut visible = false;
+ for (state_key, membership) in memberships {
+ let state_key = UserID::try_from(state_key.as_ref())
+ .map_err(|e| anyhow::anyhow!(format!("invalid user_id ({state_key}): {e}")))?;
+ if state_key.server_name() != target_server_name {
+ return Err(anyhow::anyhow!(
+ "state_key.server_name ({}) does not match target_server_name ({target_server_name})",
+ state_key.server_name()
+ ));
+ }
+
+ match membership.as_str() {
+ MEMBERSHIP_INVITE => {
+ if history_visibility == HISTORY_VISIBILITY_INVITED {
+ visible = true;
+ break;
+ }
+ }
+ MEMBERSHIP_JOIN => {
+ visible = true;
+ break;
+ }
+ _ => continue,
+ }
+ }
+
+ Ok(visible)
+}
diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs
index ad87825f16..eeb6074c10 100644
--- a/rust/src/events/internal_metadata.rs
+++ b/rust/src/events/internal_metadata.rs
@@ -41,9 +41,11 @@ use pyo3::{
pybacked::PyBackedStr,
pyclass, pymethods,
types::{PyAnyMethods, PyDict, PyDictMethods, PyString},
- Bound, IntoPy, PyAny, PyObject, PyResult, Python,
+ Bound, IntoPyObject, PyAny, PyObject, PyResult, Python,
};
+use crate::UnwrapInfallible;
+
/// Definitions of the various fields of the internal metadata.
#[derive(Clone)]
enum EventInternalMetadataData {
@@ -60,31 +62,59 @@ enum EventInternalMetadataData {
impl EventInternalMetadataData {
/// Convert the field to its name and python object.
- fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, PyObject) {
+ fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, Bound<'a, PyAny>) {
match self {
- EventInternalMetadataData::OutOfBandMembership(o) => {
- (pyo3::intern!(py, "out_of_band_membership"), o.into_py(py))
- }
- EventInternalMetadataData::SendOnBehalfOf(o) => {
- (pyo3::intern!(py, "send_on_behalf_of"), o.into_py(py))
- }
- EventInternalMetadataData::RecheckRedaction(o) => {
- (pyo3::intern!(py, "recheck_redaction"), o.into_py(py))
- }
- EventInternalMetadataData::SoftFailed(o) => {
- (pyo3::intern!(py, "soft_failed"), o.into_py(py))
- }
- EventInternalMetadataData::ProactivelySend(o) => {
- (pyo3::intern!(py, "proactively_send"), o.into_py(py))
- }
- EventInternalMetadataData::Redacted(o) => {
- (pyo3::intern!(py, "redacted"), o.into_py(py))
- }
- EventInternalMetadataData::TxnId(o) => (pyo3::intern!(py, "txn_id"), o.into_py(py)),
- EventInternalMetadataData::TokenId(o) => (pyo3::intern!(py, "token_id"), o.into_py(py)),
- EventInternalMetadataData::DeviceId(o) => {
- (pyo3::intern!(py, "device_id"), o.into_py(py))
- }
+ EventInternalMetadataData::OutOfBandMembership(o) => (
+ pyo3::intern!(py, "out_of_band_membership"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::SendOnBehalfOf(o) => (
+ pyo3::intern!(py, "send_on_behalf_of"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::RecheckRedaction(o) => (
+ pyo3::intern!(py, "recheck_redaction"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::SoftFailed(o) => (
+ pyo3::intern!(py, "soft_failed"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::ProactivelySend(o) => (
+ pyo3::intern!(py, "proactively_send"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::Redacted(o) => (
+ pyo3::intern!(py, "redacted"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::TxnId(o) => (
+ pyo3::intern!(py, "txn_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::TokenId(o) => (
+ pyo3::intern!(py, "token_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::DeviceId(o) => (
+ pyo3::intern!(py, "device_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
}
}
@@ -247,7 +277,7 @@ impl EventInternalMetadata {
///
/// Note that `outlier` and `stream_ordering` are stored in separate columns so are not returned here.
fn get_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
- let dict = PyDict::new_bound(py);
+ let dict = PyDict::new(py);
for entry in &self.data {
let (key, value) = entry.to_python_pair(py);
diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs
index a4ade1a178..209efb917b 100644
--- a/rust/src/events/mod.rs
+++ b/rust/src/events/mod.rs
@@ -22,21 +22,23 @@
use pyo3::{
types::{PyAnyMethods, PyModule, PyModuleMethods},
- Bound, PyResult, Python,
+ wrap_pyfunction, Bound, PyResult, Python,
};
+pub mod filter;
mod internal_metadata;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
- let child_module = PyModule::new_bound(py, "events")?;
+ let child_module = PyModule::new(py, "events")?;
child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
+ child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
m.add_submodule(&child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import events` work.
- py.import_bound("sys")?
+ py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.events", child_module)?;
diff --git a/rust/src/http.rs b/rust/src/http.rs
index af052ab721..63ed05be54 100644
--- a/rust/src/http.rs
+++ b/rust/src/http.rs
@@ -70,7 +70,7 @@ pub fn http_request_from_twisted(request: &Bound<'_, PyAny>) -> PyResult<Request
let headers_iter = request
.getattr("requestHeaders")?
.call_method0("getAllRawHeaders")?
- .iter()?;
+ .try_iter()?;
for header in headers_iter {
let header = header?;
diff --git a/rust/src/identifier.rs b/rust/src/identifier.rs
new file mode 100644
index 0000000000..b70f6a30c7
--- /dev/null
+++ b/rust/src/identifier.rs
@@ -0,0 +1,252 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+//! # Matrix Identifiers
+//!
+//! This module contains definitions and utilities for working with matrix identifiers.
+
+use std::{fmt, ops::Deref};
+
+/// Errors that can occur when parsing a matrix identifier.
+#[derive(Clone, Debug, PartialEq)]
+pub enum IdentifierError {
+ IncorrectSigil,
+ MissingColon,
+}
+
+impl fmt::Display for IdentifierError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+/// A Matrix user_id.
+#[derive(Clone, Debug, PartialEq)]
+pub struct UserID(String);
+
+impl UserID {
+ /// Returns the `localpart` of the user_id.
+ pub fn localpart(&self) -> &str {
+ &self[1..self.colon_pos()]
+ }
+
+ /// Returns the `server_name` / `domain` of the user_id.
+ pub fn server_name(&self) -> &str {
+ &self[self.colon_pos() + 1..]
+ }
+
+ /// Returns the position of the ':' inside of the user_id.
+ /// Used when splitting the user_id into it's respective parts.
+ fn colon_pos(&self) -> usize {
+ self.find(':').unwrap()
+ }
+}
+
+impl TryFrom<&str> for UserID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `UserID` from the provided `&str`.
+ /// Can fail if the user_id is incorrectly formatted.
+ fn try_from(s: &str) -> Result<Self, Self::Error> {
+ if !s.starts_with('@') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ if s.find(':').is_none() {
+ return Err(IdentifierError::MissingColon);
+ }
+
+ Ok(UserID(s.to_string()))
+ }
+}
+
+impl TryFrom<String> for UserID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `UserID` from the provided `&str`.
+ /// Can fail if the user_id is incorrectly formatted.
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ if !s.starts_with('@') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ if s.find(':').is_none() {
+ return Err(IdentifierError::MissingColon);
+ }
+
+ Ok(UserID(s))
+ }
+}
+
+impl<'de> serde::Deserialize<'de> for UserID {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let s: String = serde::Deserialize::deserialize(deserializer)?;
+ UserID::try_from(s).map_err(serde::de::Error::custom)
+ }
+}
+
+impl Deref for UserID {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl fmt::Display for UserID {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+/// A Matrix room_id.
+#[derive(Clone, Debug, PartialEq)]
+pub struct RoomID(String);
+
+impl RoomID {
+ /// Returns the `localpart` of the room_id.
+ pub fn localpart(&self) -> &str {
+ &self[1..self.colon_pos()]
+ }
+
+ /// Returns the `server_name` / `domain` of the room_id.
+ pub fn server_name(&self) -> &str {
+ &self[self.colon_pos() + 1..]
+ }
+
+ /// Returns the position of the ':' inside of the room_id.
+ /// Used when splitting the room_id into it's respective parts.
+ fn colon_pos(&self) -> usize {
+ self.find(':').unwrap()
+ }
+}
+
+impl TryFrom<&str> for RoomID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `RoomID` from the provided `&str`.
+ /// Can fail if the room_id is incorrectly formatted.
+ fn try_from(s: &str) -> Result<Self, Self::Error> {
+ if !s.starts_with('!') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ if s.find(':').is_none() {
+ return Err(IdentifierError::MissingColon);
+ }
+
+ Ok(RoomID(s.to_string()))
+ }
+}
+
+impl TryFrom<String> for RoomID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `RoomID` from the provided `String`.
+ /// Can fail if the room_id is incorrectly formatted.
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ if !s.starts_with('!') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ if s.find(':').is_none() {
+ return Err(IdentifierError::MissingColon);
+ }
+
+ Ok(RoomID(s))
+ }
+}
+
+impl<'de> serde::Deserialize<'de> for RoomID {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let s: String = serde::Deserialize::deserialize(deserializer)?;
+ RoomID::try_from(s).map_err(serde::de::Error::custom)
+ }
+}
+
+impl Deref for RoomID {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl fmt::Display for RoomID {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+/// A Matrix event_id.
+#[derive(Clone, Debug, PartialEq)]
+pub struct EventID(String);
+
+impl TryFrom<&str> for EventID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `EventID` from the provided `&str`.
+ /// Can fail if the event_id is incorrectly formatted.
+ fn try_from(s: &str) -> Result<Self, Self::Error> {
+ if !s.starts_with('$') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ Ok(EventID(s.to_string()))
+ }
+}
+
+impl TryFrom<String> for EventID {
+ type Error = IdentifierError;
+
+ /// Will try creating a `EventID` from the provided `String`.
+ /// Can fail if the event_id is incorrectly formatted.
+ fn try_from(s: String) -> Result<Self, Self::Error> {
+ if !s.starts_with('$') {
+ return Err(IdentifierError::IncorrectSigil);
+ }
+
+ Ok(EventID(s))
+ }
+}
+
+impl<'de> serde::Deserialize<'de> for EventID {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let s: String = serde::Deserialize::deserialize(deserializer)?;
+ EventID::try_from(s).map_err(serde::de::Error::custom)
+ }
+}
+
+impl Deref for EventID {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl fmt::Display for EventID {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 06477880b9..d751889874 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,3 +1,5 @@
+use std::convert::Infallible;
+
use lazy_static::lazy_static;
use pyo3::prelude::*;
use pyo3_log::ResetHandle;
@@ -6,6 +8,8 @@ pub mod acl;
pub mod errors;
pub mod events;
pub mod http;
+pub mod identifier;
+pub mod matrix_const;
pub mod push;
pub mod rendezvous;
@@ -50,3 +54,16 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(())
}
+
+pub trait UnwrapInfallible<T> {
+ fn unwrap_infallible(self) -> T;
+}
+
+impl<T> UnwrapInfallible<T> for Result<T, Infallible> {
+ fn unwrap_infallible(self) -> T {
+ match self {
+ Ok(val) => val,
+ Err(never) => match never {},
+ }
+ }
+}
diff --git a/rust/src/matrix_const.rs b/rust/src/matrix_const.rs
new file mode 100644
index 0000000000..f75f3bd7c3
--- /dev/null
+++ b/rust/src/matrix_const.rs
@@ -0,0 +1,28 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+//! # Matrix Constants
+//!
+//! This module contains definitions for constant values described by the matrix specification.
+
+pub const HISTORY_VISIBILITY_WORLD_READABLE: &str = "world_readable";
+pub const HISTORY_VISIBILITY_SHARED: &str = "shared";
+pub const HISTORY_VISIBILITY_INVITED: &str = "invited";
+pub const HISTORY_VISIBILITY_JOINED: &str = "joined";
+
+pub const MEMBERSHIP_BAN: &str = "ban";
+pub const MEMBERSHIP_LEAVE: &str = "leave";
+pub const MEMBERSHIP_KNOCK: &str = "knock";
+pub const MEMBERSHIP_INVITE: &str = "invite";
+pub const MEMBERSHIP_JOIN: &str = "join";
diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs
index 74f02d6001..e0832ada1c 100644
--- a/rust/src/push/base_rules.rs
+++ b/rust/src/push/base_rules.rs
@@ -81,7 +81,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
))]),
actions: Cow::Borrowed(&[Action::Notify]),
default: true,
- default_enabled: false,
+ default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"),
diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs
index 2f4b6d47bb..db406acb88 100644
--- a/rust/src/push/evaluator.rs
+++ b/rust/src/push/evaluator.rs
@@ -105,6 +105,9 @@ pub struct PushRuleEvaluator {
/// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,
+
+ // If MSC4210 (remove legacy mentions) is enabled.
+ msc4210_enabled: bool,
}
#[pymethods]
@@ -122,6 +125,7 @@ impl PushRuleEvaluator {
related_event_match_enabled,
room_version_feature_flags,
msc3931_enabled,
+ msc4210_enabled,
))]
pub fn py_new(
flattened_keys: BTreeMap<String, JsonValue>,
@@ -133,6 +137,7 @@ impl PushRuleEvaluator {
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
+ msc4210_enabled: bool,
) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") {
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(),
@@ -150,6 +155,7 @@ impl PushRuleEvaluator {
related_event_match_enabled,
room_version_feature_flags,
msc3931_enabled,
+ msc4210_enabled,
})
}
@@ -161,6 +167,7 @@ impl PushRuleEvaluator {
///
/// Returns the set of actions, if any, that match (filtering out any
/// `dont_notify` and `coalesce` actions).
+ #[pyo3(signature = (push_rules, user_id=None, display_name=None))]
pub fn run(
&self,
push_rules: &FilteredPushRules,
@@ -176,7 +183,8 @@ impl PushRuleEvaluator {
// For backwards-compatibility the legacy mention rules are disabled
// if the event contains the 'm.mentions' property.
- if self.has_mentions
+ // Additionally, MSC4210 always disables the legacy rules.
+ if (self.has_mentions || self.msc4210_enabled)
&& (rule_id == "global/override/.m.rule.contains_display_name"
|| rule_id == "global/content/.m.rule.contains_user_name"
|| rule_id == "global/override/.m.rule.roomnotif")
@@ -229,6 +237,7 @@ impl PushRuleEvaluator {
}
/// Check if the given condition matches.
+ #[pyo3(signature = (condition, user_id=None, display_name=None))]
fn matches(
&self,
condition: Condition,
@@ -526,6 +535,7 @@ fn push_rule_evaluator() {
true,
vec![],
true,
+ false,
)
.unwrap();
@@ -555,6 +565,7 @@ fn test_requires_room_version_supports_condition() {
false,
flags,
true,
+ false,
)
.unwrap();
@@ -582,7 +593,7 @@ fn test_requires_room_version_supports_condition() {
};
let rules = PushRules::new(vec![custom_rule]);
result = evaluator.run(
- &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false),
+ &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false),
None,
None,
);
diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs
index 2a452b69a3..bd0e853ac3 100644
--- a/rust/src/push/mod.rs
+++ b/rust/src/push/mod.rs
@@ -65,8 +65,8 @@ use anyhow::{Context, Error};
use log::warn;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
-use pyo3::types::{PyBool, PyList, PyLong, PyString};
-use pythonize::{depythonize_bound, pythonize};
+use pyo3::types::{PyBool, PyInt, PyList, PyString};
+use pythonize::{depythonize, pythonize, PythonizeError};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -79,7 +79,7 @@ pub mod utils;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
- let child_module = PyModule::new_bound(py, "push")?;
+ let child_module = PyModule::new(py, "push")?;
child_module.add_class::<PushRule>()?;
child_module.add_class::<PushRules>()?;
child_module.add_class::<FilteredPushRules>()?;
@@ -90,7 +90,7 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import push` work.
- py.import_bound("sys")?
+ py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.push", child_module)?;
@@ -182,12 +182,16 @@ pub enum Action {
Unknown(Value),
}
-impl IntoPy<PyObject> for Action {
- fn into_py(self, py: Python<'_>) -> PyObject {
+impl<'py> IntoPyObject<'py> for Action {
+ type Target = PyAny;
+ type Output = Bound<'py, Self::Target>;
+ type Error = PythonizeError;
+
+ fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
// When we pass the `Action` struct to Python we want it to be converted
// to a dict. We use `pythonize`, which converts the struct using the
// `serde` serialization.
- pythonize(py, &self).expect("valid action")
+ pythonize(py, &self)
}
}
@@ -270,13 +274,13 @@ pub enum SimpleJsonValue {
}
impl<'source> FromPyObject<'source> for SimpleJsonValue {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if let Ok(s) = ob.downcast::<PyString>() {
Ok(SimpleJsonValue::Str(Cow::Owned(s.to_string())))
// A bool *is* an int, ensure we try bool first.
} else if let Ok(b) = ob.downcast::<PyBool>() {
Ok(SimpleJsonValue::Bool(b.extract()?))
- } else if let Ok(i) = ob.downcast::<PyLong>() {
+ } else if let Ok(i) = ob.downcast::<PyInt>() {
Ok(SimpleJsonValue::Int(i.extract()?))
} else if ob.is_none() {
Ok(SimpleJsonValue::Null)
@@ -298,15 +302,19 @@ pub enum JsonValue {
}
impl<'source> FromPyObject<'source> for JsonValue {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if let Ok(l) = ob.downcast::<PyList>() {
- match l.iter().map(SimpleJsonValue::extract).collect() {
+ match l
+ .iter()
+ .map(|it| SimpleJsonValue::extract_bound(&it))
+ .collect()
+ {
Ok(a) => Ok(JsonValue::Array(a)),
Err(e) => Err(PyTypeError::new_err(format!(
"Can't convert to JsonValue::Array: {e}"
))),
}
- } else if let Ok(v) = SimpleJsonValue::extract(ob) {
+ } else if let Ok(v) = SimpleJsonValue::extract_bound(ob) {
Ok(JsonValue::Value(v))
} else {
Err(PyTypeError::new_err(format!(
@@ -363,15 +371,19 @@ pub enum KnownCondition {
},
}
-impl IntoPy<PyObject> for Condition {
- fn into_py(self, py: Python<'_>) -> PyObject {
- pythonize(py, &self).expect("valid condition")
+impl<'source> IntoPyObject<'source> for Condition {
+ type Target = PyAny;
+ type Output = Bound<'source, Self::Target>;
+ type Error = PythonizeError;
+
+ fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
+ pythonize(py, &self)
}
}
impl<'source> FromPyObject<'source> for Condition {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
- Ok(depythonize_bound(ob.clone())?)
+ Ok(depythonize(ob)?)
}
}
@@ -534,6 +546,7 @@ pub struct FilteredPushRules {
msc3381_polls_enabled: bool,
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
+ msc4210_enabled: bool,
}
#[pymethods]
@@ -546,6 +559,7 @@ impl FilteredPushRules {
msc3381_polls_enabled: bool,
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
+ msc4210_enabled: bool,
) -> Self {
Self {
push_rules,
@@ -554,6 +568,7 @@ impl FilteredPushRules {
msc3381_polls_enabled,
msc3664_enabled,
msc4028_push_encrypted_events,
+ msc4210_enabled,
}
}
@@ -596,6 +611,14 @@ impl FilteredPushRules {
return false;
}
+ if self.msc4210_enabled
+ && (rule.rule_id == "global/override/.m.rule.contains_display_name"
+ || rule.rule_id == "global/content/.m.rule.contains_user_name"
+ || rule.rule_id == "global/override/.m.rule.roomnotif")
+ {
+ return false;
+ }
+
true
})
.map(|r| {
diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs
index 28ebed62c8..59536c9954 100644
--- a/rust/src/push/utils.rs
+++ b/rust/src/push/utils.rs
@@ -23,7 +23,6 @@ use anyhow::bail;
use anyhow::Context;
use anyhow::Error;
use lazy_static::lazy_static;
-use regex;
use regex::Regex;
use regex::RegexBuilder;
diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs
index f69f45490f..3148e0f67a 100644
--- a/rust/src/rendezvous/mod.rs
+++ b/rust/src/rendezvous/mod.rs
@@ -29,7 +29,7 @@ use pyo3::{
exceptions::PyValueError,
pyclass, pymethods,
types::{PyAnyMethods, PyModule, PyModuleMethods},
- Bound, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
+ Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python,
};
use ulid::Ulid;
@@ -37,6 +37,7 @@ use self::session::Session;
use crate::{
errors::{NotFoundError, SynapseError},
http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt},
+ UnwrapInfallible,
};
mod session;
@@ -46,7 +47,7 @@ fn prepare_headers(headers: &mut HeaderMap, session: &Session) {
headers.typed_insert(AccessControlAllowOrigin::ANY);
headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG]));
headers.typed_insert(Pragma::no_cache());
- headers.typed_insert(CacheControl::new().with_no_store());
+ headers.typed_insert(CacheControl::new().with_no_store().with_no_transform());
headers.typed_insert(session.etag());
headers.typed_insert(session.expires());
headers.typed_insert(session.last_modified());
@@ -125,7 +126,11 @@ impl RendezvousHandler {
let base = Uri::try_from(format!("{base}_synapse/client/rendezvous"))
.map_err(|_| PyValueError::new_err("Invalid base URI"))?;
- let clock = homeserver.call_method0("get_clock")?.to_object(py);
+ let clock = homeserver
+ .call_method0("get_clock")?
+ .into_pyobject(py)
+ .unwrap_infallible()
+ .unbind();
// Construct a Python object so that we can get a reference to the
// evict method and schedule it to run.
@@ -187,10 +192,12 @@ impl RendezvousHandler {
"url": uri,
})
.to_string();
+ let length = response.len() as _;
let mut response = Response::new(response.as_bytes());
*response.status_mut() = StatusCode::CREATED;
response.headers_mut().typed_insert(ContentType::json());
+ response.headers_mut().typed_insert(ContentLength(length));
prepare_headers(response.headers_mut(), &session);
http_response_to_twisted(twisted_request, response)?;
@@ -288,6 +295,14 @@ impl RendezvousHandler {
let mut response = Response::new(Bytes::new());
*response.status_mut() = StatusCode::ACCEPTED;
prepare_headers(response.headers_mut(), session);
+
+ // Even though this isn't mandated by the MSC, we set a Content-Type on the response. It
+ // doesn't do any harm as the body is empty, but this helps escape a bug in some reverse
+ // proxy/cache setup which strips the ETag header if there is no Content-Type set.
+ // Specifically, we noticed this behaviour when placing Synapse behind Cloudflare.
+ response.headers_mut().typed_insert(ContentType::text());
+ response.headers_mut().typed_insert(ContentLength(0));
+
http_response_to_twisted(twisted_request, response)?;
Ok(())
@@ -304,6 +319,7 @@ impl RendezvousHandler {
response
.headers_mut()
.typed_insert(AccessControlAllowOrigin::ANY);
+ response.headers_mut().typed_insert(ContentLength(0));
http_response_to_twisted(twisted_request, response)?;
Ok(())
@@ -311,7 +327,7 @@ impl RendezvousHandler {
}
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
- let child_module = PyModule::new_bound(py, "rendezvous")?;
+ let child_module = PyModule::new(py, "rendezvous")?;
child_module.add_class::<RendezvousHandler>()?;
@@ -319,7 +335,7 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import rendezvous` work.
- py.import_bound("sys")?
+ py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.rendezvous", child_module)?;
|