diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 9ac766182b..d41a216d1c 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -23,11 +23,13 @@ name = "synapse.synapse_rust"
[dependencies]
anyhow = "1.0.63"
+base64 = "0.21.7"
bytes = "1.6.0"
headers = "0.4.0"
http = "1.1.0"
lazy_static = "1.4.0"
log = "0.4.17"
+mime = "0.3.17"
pyo3 = { version = "0.20.0", features = [
"macros",
"anyhow",
@@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [
pyo3-log = "0.9.0"
pythonize = "0.20.0"
regex = "1.6.0"
+sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
+ulid = "1.1.2"
[features]
extension-module = ["pyo3/extension-module"]
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 36a3d64528..9bd1f17ad9 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -7,6 +7,7 @@ pub mod errors;
pub mod events;
pub mod http;
pub mod push;
+pub mod rendezvous;
lazy_static! {
static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init();
@@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
acl::register_module(py, m)?;
push::register_module(py, m)?;
events::register_module(py, m)?;
+ rendezvous::register_module(py, m)?;
Ok(())
}
diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs
new file mode 100644
index 0000000000..c0f5d8b600
--- /dev/null
+++ b/rust/src/rendezvous/mod.rs
@@ -0,0 +1,315 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ *
+ */
+
+use std::{
+ collections::{BTreeMap, HashMap},
+ time::{Duration, SystemTime},
+};
+
+use bytes::Bytes;
+use headers::{
+ AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType,
+ HeaderMapExt, IfMatch, IfNoneMatch, Pragma,
+};
+use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri};
+use mime::Mime;
+use pyo3::{
+ exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult,
+ Python, ToPyObject,
+};
+use ulid::Ulid;
+
+use self::session::Session;
+use crate::{
+ errors::{NotFoundError, SynapseError},
+ http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt},
+};
+
+mod session;
+
+// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers.
+fn prepare_headers(headers: &mut HeaderMap, session: &Session) {
+ headers.typed_insert(AccessControlAllowOrigin::ANY);
+ headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG]));
+ headers.typed_insert(Pragma::no_cache());
+ headers.typed_insert(CacheControl::new().with_no_store());
+ headers.typed_insert(session.etag());
+ headers.typed_insert(session.expires());
+ headers.typed_insert(session.last_modified());
+}
+
+#[pyclass]
+struct RendezvousHandler {
+ base: Uri,
+ clock: PyObject,
+ sessions: BTreeMap<Ulid, Session>,
+ capacity: usize,
+ max_content_length: u64,
+ ttl: Duration,
+}
+
+impl RendezvousHandler {
+ /// Check the input headers of a request which sets data for a session, and return the content type.
+ fn check_input_headers(&self, headers: &HeaderMap) -> PyResult<Mime> {
+ let ContentLength(content_length) = headers.typed_get_required()?;
+
+ if content_length > self.max_content_length {
+ return Err(SynapseError::new(
+ StatusCode::PAYLOAD_TOO_LARGE,
+ "Payload too large".to_owned(),
+ "M_TOO_LARGE",
+ None,
+ None,
+ ));
+ }
+
+ let content_type: ContentType = headers.typed_get_required()?;
+
+ // Content-Type must be text/plain
+ if content_type != ContentType::text() {
+ return Err(SynapseError::new(
+ StatusCode::BAD_REQUEST,
+ "Content-Type must be text/plain".to_owned(),
+ "M_INVALID_PARAM",
+ None,
+ None,
+ ));
+ }
+
+ Ok(content_type.into())
+ }
+
+ /// Evict expired sessions and remove the oldest sessions until we're under the capacity.
+ fn evict(&mut self, now: SystemTime) {
+ // First remove all the entries which expired
+ self.sessions.retain(|_, session| !session.expired(now));
+
+ // Then we remove the oldest entires until we're under the limit
+ while self.sessions.len() > self.capacity {
+ self.sessions.pop_first();
+ }
+ }
+}
+
+#[pymethods]
+impl RendezvousHandler {
+ #[new]
+ #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))]
+ fn new(
+ py: Python<'_>,
+ homeserver: &PyAny,
+ capacity: usize,
+ max_content_length: u64,
+ eviction_interval: u64,
+ ttl: u64,
+ ) -> PyResult<Py<Self>> {
+ let base: String = homeserver
+ .getattr("config")?
+ .getattr("server")?
+ .getattr("public_baseurl")?
+ .extract()?;
+ let base = Uri::try_from(format!("{base}_synapse/client/rendezvous"))
+ .map_err(|_| PyValueError::new_err("Invalid base URI"))?;
+
+ let clock = homeserver.call_method0("get_clock")?.to_object(py);
+
+ // Construct a Python object so that we can get a reference to the
+ // evict method and schedule it to run.
+ let self_ = Py::new(
+ py,
+ Self {
+ base,
+ clock,
+ sessions: BTreeMap::new(),
+ capacity,
+ max_content_length,
+ ttl: Duration::from_millis(ttl),
+ },
+ )?;
+
+ let evict = self_.getattr(py, "_evict")?;
+ homeserver.call_method0("get_clock")?.call_method(
+ "looping_call",
+ (evict, eviction_interval),
+ None,
+ )?;
+
+ Ok(self_)
+ }
+
+ fn _evict(&mut self, py: Python<'_>) -> PyResult<()> {
+ let clock = self.clock.as_ref(py);
+ let now: u64 = clock.call_method0("time_msec")?.extract()?;
+ let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+ self.evict(now);
+
+ Ok(())
+ }
+
+ fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> {
+ let request = http_request_from_twisted(twisted_request)?;
+
+ let content_type = self.check_input_headers(request.headers())?;
+
+ let clock = self.clock.as_ref(py);
+ let now: u64 = clock.call_method0("time_msec")?.extract()?;
+ let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+ // We trigger an immediate eviction if we're at 2x the capacity
+ if self.sessions.len() >= self.capacity * 2 {
+ self.evict(now);
+ }
+
+ // Generate a new ULID for the session from the current time.
+ let id = Ulid::from_datetime(now);
+
+ let uri = format!("{base}/{id}", base = self.base);
+
+ let body = request.into_body();
+
+ let session = Session::new(body, content_type, now, self.ttl);
+
+ let response = serde_json::json!({
+ "url": uri,
+ })
+ .to_string();
+
+ let mut response = Response::new(response.as_bytes());
+ *response.status_mut() = StatusCode::CREATED;
+ response.headers_mut().typed_insert(ContentType::json());
+ prepare_headers(response.headers_mut(), &session);
+ http_response_to_twisted(twisted_request, response)?;
+
+ self.sessions.insert(id, session);
+
+ Ok(())
+ }
+
+ fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+ let request = http_request_from_twisted(twisted_request)?;
+
+ let if_none_match: Option<IfNoneMatch> = request.headers().typed_get_optional()?;
+
+ let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
+ let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+ let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+ let session = self
+ .sessions
+ .get(&id)
+ .filter(|s| !s.expired(now))
+ .ok_or_else(NotFoundError::new)?;
+
+ if let Some(if_none_match) = if_none_match {
+ if !if_none_match.precondition_passes(&session.etag()) {
+ let mut response = Response::new(Bytes::new());
+ *response.status_mut() = StatusCode::NOT_MODIFIED;
+ prepare_headers(response.headers_mut(), session);
+ http_response_to_twisted(twisted_request, response)?;
+ return Ok(());
+ }
+ }
+
+ let mut response = Response::new(session.data());
+ *response.status_mut() = StatusCode::OK;
+ let headers = response.headers_mut();
+ prepare_headers(headers, session);
+ headers.typed_insert(session.content_type());
+ headers.typed_insert(session.content_length());
+ http_response_to_twisted(twisted_request, response)?;
+
+ Ok(())
+ }
+
+ fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+ let request = http_request_from_twisted(twisted_request)?;
+
+ let content_type = self.check_input_headers(request.headers())?;
+
+ let if_match: IfMatch = request.headers().typed_get_required()?;
+
+ let data = request.into_body();
+
+ let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
+ let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+ let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+ let session = self
+ .sessions
+ .get_mut(&id)
+ .filter(|s| !s.expired(now))
+ .ok_or_else(NotFoundError::new)?;
+
+ if !if_match.precondition_passes(&session.etag()) {
+ let mut headers = HeaderMap::new();
+ prepare_headers(&mut headers, session);
+
+ let mut additional_fields = HashMap::with_capacity(1);
+ additional_fields.insert(
+ String::from("org.matrix.msc4108.errcode"),
+ String::from("M_CONCURRENT_WRITE"),
+ );
+
+ return Err(SynapseError::new(
+ StatusCode::PRECONDITION_FAILED,
+ "ETag does not match".to_owned(),
+ "M_UNKNOWN", // Would be M_CONCURRENT_WRITE
+ Some(additional_fields),
+ Some(headers),
+ ));
+ }
+
+ session.update(data, content_type, now);
+
+ let mut response = Response::new(Bytes::new());
+ *response.status_mut() = StatusCode::ACCEPTED;
+ prepare_headers(response.headers_mut(), session);
+ http_response_to_twisted(twisted_request, response)?;
+
+ Ok(())
+ }
+
+ fn handle_delete(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+ let _request = http_request_from_twisted(twisted_request)?;
+
+ let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+ let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?;
+
+ let mut response = Response::new(Bytes::new());
+ *response.status_mut() = StatusCode::NO_CONTENT;
+ response
+ .headers_mut()
+ .typed_insert(AccessControlAllowOrigin::ANY);
+ http_response_to_twisted(twisted_request, response)?;
+
+ Ok(())
+ }
+}
+
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+ let child_module = PyModule::new(py, "rendezvous")?;
+
+ child_module.add_class::<RendezvousHandler>()?;
+
+ m.add_submodule(child_module)?;
+
+ // We need to manually add the module to sys.modules to make `from
+ // synapse.synapse_rust import rendezvous` work.
+ py.import("sys")?
+ .getattr("modules")?
+ .set_item("synapse.synapse_rust.rendezvous", child_module)?;
+
+ Ok(())
+}
diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs
new file mode 100644
index 0000000000..179304edfe
--- /dev/null
+++ b/rust/src/rendezvous/session.rs
@@ -0,0 +1,91 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use std::time::{Duration, SystemTime};
+
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
+use bytes::Bytes;
+use headers::{ContentLength, ContentType, ETag, Expires, LastModified};
+use mime::Mime;
+use sha2::{Digest, Sha256};
+
+/// A single session, containing data, metadata, and expiry information.
+pub struct Session {
+ hash: [u8; 32],
+ data: Bytes,
+ content_type: Mime,
+ last_modified: SystemTime,
+ expires: SystemTime,
+}
+
+impl Session {
+ /// Create a new session with the given data, content type, and time-to-live.
+ pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self {
+ let hash = Sha256::digest(&data).into();
+ Self {
+ hash,
+ data,
+ content_type,
+ expires: now + ttl,
+ last_modified: now,
+ }
+ }
+
+ /// Returns true if the session has expired at the given time.
+ pub fn expired(&self, now: SystemTime) -> bool {
+ self.expires <= now
+ }
+
+ /// Update the session with new data, content type, and last modified time.
+ pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) {
+ self.hash = Sha256::digest(&data).into();
+ self.data = data;
+ self.content_type = content_type;
+ self.last_modified = now;
+ }
+
+ /// Returns the Content-Type header of the session.
+ pub fn content_type(&self) -> ContentType {
+ self.content_type.clone().into()
+ }
+
+ /// Returns the Content-Length header of the session.
+ pub fn content_length(&self) -> ContentLength {
+ ContentLength(self.data.len() as _)
+ }
+
+ /// Returns the ETag header of the session.
+ pub fn etag(&self) -> ETag {
+ let encoded = URL_SAFE_NO_PAD.encode(self.hash);
+ // SAFETY: Base64 encoding is URL-safe, so ETag-safe
+ format!("\"{encoded}\"")
+ .parse()
+ .expect("base64-encoded hash should be URL-safe")
+ }
+
+ /// Returns the Last-Modified header of the session.
+ pub fn last_modified(&self) -> LastModified {
+ self.last_modified.into()
+ }
+
+ /// Returns the Expires header of the session.
+ pub fn expires(&self) -> Expires {
+ self.expires.into()
+ }
+
+ /// Returns the current data stored in the session.
+ pub fn data(&self) -> Bytes {
+ self.data.clone()
+ }
+}
|