diff --git a/rust/src/events/filter.rs b/rust/src/events/filter.rs
new file mode 100644
index 0000000000..7e39972c62
--- /dev/null
+++ b/rust/src/events/filter.rs
@@ -0,0 +1,107 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use std::collections::HashMap;
+
+use pyo3::{exceptions::PyValueError, pyfunction, PyResult};
+
+use crate::{
+ identifier::UserID,
+ matrix_const::{
+ HISTORY_VISIBILITY_INVITED, HISTORY_VISIBILITY_JOINED, MEMBERSHIP_INVITE, MEMBERSHIP_JOIN,
+ },
+};
+
+#[pyfunction(name = "event_visible_to_server")]
+pub fn event_visible_to_server_py(
+ sender: String,
+ target_server_name: String,
+ history_visibility: String,
+ erased_senders: HashMap<String, bool>,
+ partial_state_invisible: bool,
+ memberships: Vec<(String, String)>, // (state_key, membership)
+) -> PyResult<bool> {
+ event_visible_to_server(
+ sender,
+ target_server_name,
+ history_visibility,
+ erased_senders,
+ partial_state_invisible,
+ memberships,
+ )
+ .map_err(|e| PyValueError::new_err(format!("{e}")))
+}
+
+/// Return whether the target server is allowed to see the event.
+///
+/// For a fully stated room, the target server is allowed to see an event E if:
+/// - the state at E has world readable or shared history vis, OR
+/// - the state at E says that the target server is in the room.
+///
+/// For a partially stated room, the target server is allowed to see E if:
+/// - E was created by this homeserver, AND:
+/// - the partial state at E has world readable or shared history vis, OR
+/// - the partial state at E says that the target server is in the room.
+pub fn event_visible_to_server(
+ sender: String,
+ target_server_name: String,
+ history_visibility: String,
+ erased_senders: HashMap<String, bool>,
+ partial_state_invisible: bool,
+ memberships: Vec<(String, String)>, // (state_key, membership)
+) -> anyhow::Result<bool> {
+ if let Some(&erased) = erased_senders.get(&sender) {
+ if erased {
+ return Ok(false);
+ }
+ }
+
+ if partial_state_invisible {
+ return Ok(false);
+ }
+
+ if history_visibility != HISTORY_VISIBILITY_INVITED
+ && history_visibility != HISTORY_VISIBILITY_JOINED
+ {
+ return Ok(true);
+ }
+
+ let mut visible = false;
+ for (state_key, membership) in memberships {
+ let state_key = UserID::try_from(state_key.as_ref())
+ .map_err(|e| anyhow::anyhow!(format!("invalid user_id ({state_key}): {e}")))?;
+ if state_key.server_name() != target_server_name {
+ return Err(anyhow::anyhow!(
+ "state_key.server_name ({}) does not match target_server_name ({target_server_name})",
+ state_key.server_name()
+ ));
+ }
+
+ match membership.as_str() {
+ MEMBERSHIP_INVITE => {
+ if history_visibility == HISTORY_VISIBILITY_INVITED {
+ visible = true;
+ break;
+ }
+ }
+ MEMBERSHIP_JOIN => {
+ visible = true;
+ break;
+ }
+ _ => continue,
+ }
+ }
+
+ Ok(visible)
+}
diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs
index ad87825f16..eeb6074c10 100644
--- a/rust/src/events/internal_metadata.rs
+++ b/rust/src/events/internal_metadata.rs
@@ -41,9 +41,11 @@ use pyo3::{
pybacked::PyBackedStr,
pyclass, pymethods,
types::{PyAnyMethods, PyDict, PyDictMethods, PyString},
- Bound, IntoPy, PyAny, PyObject, PyResult, Python,
+ Bound, IntoPyObject, PyAny, PyObject, PyResult, Python,
};
+use crate::UnwrapInfallible;
+
/// Definitions of the various fields of the internal metadata.
#[derive(Clone)]
enum EventInternalMetadataData {
@@ -60,31 +62,59 @@ enum EventInternalMetadataData {
impl EventInternalMetadataData {
/// Convert the field to its name and python object.
- fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, PyObject) {
+ fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, Bound<'a, PyAny>) {
match self {
- EventInternalMetadataData::OutOfBandMembership(o) => {
- (pyo3::intern!(py, "out_of_band_membership"), o.into_py(py))
- }
- EventInternalMetadataData::SendOnBehalfOf(o) => {
- (pyo3::intern!(py, "send_on_behalf_of"), o.into_py(py))
- }
- EventInternalMetadataData::RecheckRedaction(o) => {
- (pyo3::intern!(py, "recheck_redaction"), o.into_py(py))
- }
- EventInternalMetadataData::SoftFailed(o) => {
- (pyo3::intern!(py, "soft_failed"), o.into_py(py))
- }
- EventInternalMetadataData::ProactivelySend(o) => {
- (pyo3::intern!(py, "proactively_send"), o.into_py(py))
- }
- EventInternalMetadataData::Redacted(o) => {
- (pyo3::intern!(py, "redacted"), o.into_py(py))
- }
- EventInternalMetadataData::TxnId(o) => (pyo3::intern!(py, "txn_id"), o.into_py(py)),
- EventInternalMetadataData::TokenId(o) => (pyo3::intern!(py, "token_id"), o.into_py(py)),
- EventInternalMetadataData::DeviceId(o) => {
- (pyo3::intern!(py, "device_id"), o.into_py(py))
- }
+ EventInternalMetadataData::OutOfBandMembership(o) => (
+ pyo3::intern!(py, "out_of_band_membership"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::SendOnBehalfOf(o) => (
+ pyo3::intern!(py, "send_on_behalf_of"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::RecheckRedaction(o) => (
+ pyo3::intern!(py, "recheck_redaction"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::SoftFailed(o) => (
+ pyo3::intern!(py, "soft_failed"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::ProactivelySend(o) => (
+ pyo3::intern!(py, "proactively_send"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::Redacted(o) => (
+ pyo3::intern!(py, "redacted"),
+ o.into_pyobject(py)
+ .unwrap_infallible()
+ .to_owned()
+ .into_any(),
+ ),
+ EventInternalMetadataData::TxnId(o) => (
+ pyo3::intern!(py, "txn_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::TokenId(o) => (
+ pyo3::intern!(py, "token_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
+ EventInternalMetadataData::DeviceId(o) => (
+ pyo3::intern!(py, "device_id"),
+ o.into_pyobject(py).unwrap_infallible().into_any(),
+ ),
}
}
@@ -247,7 +277,7 @@ impl EventInternalMetadata {
///
/// Note that `outlier` and `stream_ordering` are stored in separate columns so are not returned here.
fn get_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
- let dict = PyDict::new_bound(py);
+ let dict = PyDict::new(py);
for entry in &self.data {
let (key, value) = entry.to_python_pair(py);
diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs
index a4ade1a178..209efb917b 100644
--- a/rust/src/events/mod.rs
+++ b/rust/src/events/mod.rs
@@ -22,21 +22,23 @@
use pyo3::{
types::{PyAnyMethods, PyModule, PyModuleMethods},
- Bound, PyResult, Python,
+ wrap_pyfunction, Bound, PyResult, Python,
};
+pub mod filter;
mod internal_metadata;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
- let child_module = PyModule::new_bound(py, "events")?;
+ let child_module = PyModule::new(py, "events")?;
child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
+ child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
m.add_submodule(&child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import events` work.
- py.import_bound("sys")?
+ py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.events", child_module)?;
|