summary refs log tree commit diff
path: root/rust
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2024-04-25 14:50:12 +0200
committerGitHub <noreply@github.com>2024-04-25 12:50:12 +0000
commit2e92b718d5ea063af4b2dc9412dcd2ce625b4987 (patch)
tree9d916997bcd61eaf055c622b2aa27919faeb9cc9 /rust
parentAdd type annotation to `visited_chains` (#17125) (diff)
downloadsynapse-2e92b718d5ea063af4b2dc9412dcd2ce625b4987.tar.xz
MSC4108 implementation (#17056)
Co-authored-by: Hugh Nimmo-Smith <hughns@element.io>
Co-authored-by: Hugh Nimmo-Smith <hughns@users.noreply.github.com>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Diffstat (limited to 'rust')
-rw-r--r--rust/Cargo.toml4
-rw-r--r--rust/src/lib.rs2
-rw-r--r--rust/src/rendezvous/mod.rs315
-rw-r--r--rust/src/rendezvous/session.rs91
4 files changed, 412 insertions, 0 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 9ac766182b..d41a216d1c 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -23,11 +23,13 @@ name = "synapse.synapse_rust"
 
 [dependencies]
 anyhow = "1.0.63"
+base64 = "0.21.7"
 bytes = "1.6.0"
 headers = "0.4.0"
 http = "1.1.0"
 lazy_static = "1.4.0"
 log = "0.4.17"
+mime = "0.3.17"
 pyo3 = { version = "0.20.0", features = [
     "macros",
     "anyhow",
@@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [
 pyo3-log = "0.9.0"
 pythonize = "0.20.0"
 regex = "1.6.0"
+sha2 = "0.10.8"
 serde = { version = "1.0.144", features = ["derive"] }
 serde_json = "1.0.85"
+ulid = "1.1.2"
 
 [features]
 extension-module = ["pyo3/extension-module"]
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 36a3d64528..9bd1f17ad9 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -7,6 +7,7 @@ pub mod errors;
 pub mod events;
 pub mod http;
 pub mod push;
+pub mod rendezvous;
 
 lazy_static! {
     static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init();
@@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     acl::register_module(py, m)?;
     push::register_module(py, m)?;
     events::register_module(py, m)?;
+    rendezvous::register_module(py, m)?;
 
     Ok(())
 }
diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs
new file mode 100644
index 0000000000..c0f5d8b600
--- /dev/null
+++ b/rust/src/rendezvous/mod.rs
@@ -0,0 +1,315 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ *
+ */
+
+use std::{
+    collections::{BTreeMap, HashMap},
+    time::{Duration, SystemTime},
+};
+
+use bytes::Bytes;
+use headers::{
+    AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType,
+    HeaderMapExt, IfMatch, IfNoneMatch, Pragma,
+};
+use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri};
+use mime::Mime;
+use pyo3::{
+    exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult,
+    Python, ToPyObject,
+};
+use ulid::Ulid;
+
+use self::session::Session;
+use crate::{
+    errors::{NotFoundError, SynapseError},
+    http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt},
+};
+
+mod session;
+
+// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers.
+fn prepare_headers(headers: &mut HeaderMap, session: &Session) {
+    headers.typed_insert(AccessControlAllowOrigin::ANY);
+    headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG]));
+    headers.typed_insert(Pragma::no_cache());
+    headers.typed_insert(CacheControl::new().with_no_store());
+    headers.typed_insert(session.etag());
+    headers.typed_insert(session.expires());
+    headers.typed_insert(session.last_modified());
+}
+
+#[pyclass]
+struct RendezvousHandler {
+    base: Uri,
+    clock: PyObject,
+    sessions: BTreeMap<Ulid, Session>,
+    capacity: usize,
+    max_content_length: u64,
+    ttl: Duration,
+}
+
+impl RendezvousHandler {
+    /// Check the input headers of a request which sets data for a session, and return the content type.
+    fn check_input_headers(&self, headers: &HeaderMap) -> PyResult<Mime> {
+        let ContentLength(content_length) = headers.typed_get_required()?;
+
+        if content_length > self.max_content_length {
+            return Err(SynapseError::new(
+                StatusCode::PAYLOAD_TOO_LARGE,
+                "Payload too large".to_owned(),
+                "M_TOO_LARGE",
+                None,
+                None,
+            ));
+        }
+
+        let content_type: ContentType = headers.typed_get_required()?;
+
+        // Content-Type must be text/plain
+        if content_type != ContentType::text() {
+            return Err(SynapseError::new(
+                StatusCode::BAD_REQUEST,
+                "Content-Type must be text/plain".to_owned(),
+                "M_INVALID_PARAM",
+                None,
+                None,
+            ));
+        }
+
+        Ok(content_type.into())
+    }
+
+    /// Evict expired sessions and remove the oldest sessions until we're under the capacity.
+    fn evict(&mut self, now: SystemTime) {
+        // First remove all the entries which expired
+        self.sessions.retain(|_, session| !session.expired(now));
+
+        // Then we remove the oldest entires until we're under the limit
+        while self.sessions.len() > self.capacity {
+            self.sessions.pop_first();
+        }
+    }
+}
+
+#[pymethods]
+impl RendezvousHandler {
+    #[new]
+    #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))]
+    fn new(
+        py: Python<'_>,
+        homeserver: &PyAny,
+        capacity: usize,
+        max_content_length: u64,
+        eviction_interval: u64,
+        ttl: u64,
+    ) -> PyResult<Py<Self>> {
+        let base: String = homeserver
+            .getattr("config")?
+            .getattr("server")?
+            .getattr("public_baseurl")?
+            .extract()?;
+        let base = Uri::try_from(format!("{base}_synapse/client/rendezvous"))
+            .map_err(|_| PyValueError::new_err("Invalid base URI"))?;
+
+        let clock = homeserver.call_method0("get_clock")?.to_object(py);
+
+        // Construct a Python object so that we can get a reference to the
+        // evict method and schedule it to run.
+        let self_ = Py::new(
+            py,
+            Self {
+                base,
+                clock,
+                sessions: BTreeMap::new(),
+                capacity,
+                max_content_length,
+                ttl: Duration::from_millis(ttl),
+            },
+        )?;
+
+        let evict = self_.getattr(py, "_evict")?;
+        homeserver.call_method0("get_clock")?.call_method(
+            "looping_call",
+            (evict, eviction_interval),
+            None,
+        )?;
+
+        Ok(self_)
+    }
+
+    fn _evict(&mut self, py: Python<'_>) -> PyResult<()> {
+        let clock = self.clock.as_ref(py);
+        let now: u64 = clock.call_method0("time_msec")?.extract()?;
+        let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+        self.evict(now);
+
+        Ok(())
+    }
+
+    fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> {
+        let request = http_request_from_twisted(twisted_request)?;
+
+        let content_type = self.check_input_headers(request.headers())?;
+
+        let clock = self.clock.as_ref(py);
+        let now: u64 = clock.call_method0("time_msec")?.extract()?;
+        let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+        // We trigger an immediate eviction if we're at 2x the capacity
+        if self.sessions.len() >= self.capacity * 2 {
+            self.evict(now);
+        }
+
+        // Generate a new ULID for the session from the current time.
+        let id = Ulid::from_datetime(now);
+
+        let uri = format!("{base}/{id}", base = self.base);
+
+        let body = request.into_body();
+
+        let session = Session::new(body, content_type, now, self.ttl);
+
+        let response = serde_json::json!({
+            "url": uri,
+        })
+        .to_string();
+
+        let mut response = Response::new(response.as_bytes());
+        *response.status_mut() = StatusCode::CREATED;
+        response.headers_mut().typed_insert(ContentType::json());
+        prepare_headers(response.headers_mut(), &session);
+        http_response_to_twisted(twisted_request, response)?;
+
+        self.sessions.insert(id, session);
+
+        Ok(())
+    }
+
+    fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+        let request = http_request_from_twisted(twisted_request)?;
+
+        let if_none_match: Option<IfNoneMatch> = request.headers().typed_get_optional()?;
+
+        let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
+        let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+        let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+        let session = self
+            .sessions
+            .get(&id)
+            .filter(|s| !s.expired(now))
+            .ok_or_else(NotFoundError::new)?;
+
+        if let Some(if_none_match) = if_none_match {
+            if !if_none_match.precondition_passes(&session.etag()) {
+                let mut response = Response::new(Bytes::new());
+                *response.status_mut() = StatusCode::NOT_MODIFIED;
+                prepare_headers(response.headers_mut(), session);
+                http_response_to_twisted(twisted_request, response)?;
+                return Ok(());
+            }
+        }
+
+        let mut response = Response::new(session.data());
+        *response.status_mut() = StatusCode::OK;
+        let headers = response.headers_mut();
+        prepare_headers(headers, session);
+        headers.typed_insert(session.content_type());
+        headers.typed_insert(session.content_length());
+        http_response_to_twisted(twisted_request, response)?;
+
+        Ok(())
+    }
+
+    fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+        let request = http_request_from_twisted(twisted_request)?;
+
+        let content_type = self.check_input_headers(request.headers())?;
+
+        let if_match: IfMatch = request.headers().typed_get_required()?;
+
+        let data = request.into_body();
+
+        let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
+        let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
+
+        let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+        let session = self
+            .sessions
+            .get_mut(&id)
+            .filter(|s| !s.expired(now))
+            .ok_or_else(NotFoundError::new)?;
+
+        if !if_match.precondition_passes(&session.etag()) {
+            let mut headers = HeaderMap::new();
+            prepare_headers(&mut headers, session);
+
+            let mut additional_fields = HashMap::with_capacity(1);
+            additional_fields.insert(
+                String::from("org.matrix.msc4108.errcode"),
+                String::from("M_CONCURRENT_WRITE"),
+            );
+
+            return Err(SynapseError::new(
+                StatusCode::PRECONDITION_FAILED,
+                "ETag does not match".to_owned(),
+                "M_UNKNOWN", // Would be M_CONCURRENT_WRITE
+                Some(additional_fields),
+                Some(headers),
+            ));
+        }
+
+        session.update(data, content_type, now);
+
+        let mut response = Response::new(Bytes::new());
+        *response.status_mut() = StatusCode::ACCEPTED;
+        prepare_headers(response.headers_mut(), session);
+        http_response_to_twisted(twisted_request, response)?;
+
+        Ok(())
+    }
+
+    fn handle_delete(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> {
+        let _request = http_request_from_twisted(twisted_request)?;
+
+        let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
+        let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?;
+
+        let mut response = Response::new(Bytes::new());
+        *response.status_mut() = StatusCode::NO_CONTENT;
+        response
+            .headers_mut()
+            .typed_insert(AccessControlAllowOrigin::ANY);
+        http_response_to_twisted(twisted_request, response)?;
+
+        Ok(())
+    }
+}
+
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let child_module = PyModule::new(py, "rendezvous")?;
+
+    child_module.add_class::<RendezvousHandler>()?;
+
+    m.add_submodule(child_module)?;
+
+    // We need to manually add the module to sys.modules to make `from
+    // synapse.synapse_rust import rendezvous` work.
+    py.import("sys")?
+        .getattr("modules")?
+        .set_item("synapse.synapse_rust.rendezvous", child_module)?;
+
+    Ok(())
+}
diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs
new file mode 100644
index 0000000000..179304edfe
--- /dev/null
+++ b/rust/src/rendezvous/session.rs
@@ -0,0 +1,91 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use std::time::{Duration, SystemTime};
+
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
+use bytes::Bytes;
+use headers::{ContentLength, ContentType, ETag, Expires, LastModified};
+use mime::Mime;
+use sha2::{Digest, Sha256};
+
+/// A single session, containing data, metadata, and expiry information.
+pub struct Session {
+    hash: [u8; 32],
+    data: Bytes,
+    content_type: Mime,
+    last_modified: SystemTime,
+    expires: SystemTime,
+}
+
+impl Session {
+    /// Create a new session with the given data, content type, and time-to-live.
+    pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self {
+        let hash = Sha256::digest(&data).into();
+        Self {
+            hash,
+            data,
+            content_type,
+            expires: now + ttl,
+            last_modified: now,
+        }
+    }
+
+    /// Returns true if the session has expired at the given time.
+    pub fn expired(&self, now: SystemTime) -> bool {
+        self.expires <= now
+    }
+
+    /// Update the session with new data, content type, and last modified time.
+    pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) {
+        self.hash = Sha256::digest(&data).into();
+        self.data = data;
+        self.content_type = content_type;
+        self.last_modified = now;
+    }
+
+    /// Returns the Content-Type header of the session.
+    pub fn content_type(&self) -> ContentType {
+        self.content_type.clone().into()
+    }
+
+    /// Returns the Content-Length header of the session.
+    pub fn content_length(&self) -> ContentLength {
+        ContentLength(self.data.len() as _)
+    }
+
+    /// Returns the ETag header of the session.
+    pub fn etag(&self) -> ETag {
+        let encoded = URL_SAFE_NO_PAD.encode(self.hash);
+        // SAFETY: Base64 encoding is URL-safe, so ETag-safe
+        format!("\"{encoded}\"")
+            .parse()
+            .expect("base64-encoded hash should be URL-safe")
+    }
+
+    /// Returns the Last-Modified header of the session.
+    pub fn last_modified(&self) -> LastModified {
+        self.last_modified.into()
+    }
+
+    /// Returns the Expires header of the session.
+    pub fn expires(&self) -> Expires {
+        self.expires.into()
+    }
+
+    /// Returns the current data stored in the session.
+    pub fn data(&self) -> Bytes {
+        self.data.clone()
+    }
+}