diff options
author | Quentin Gliech <quenting@element.io> | 2024-04-25 14:50:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-25 12:50:12 +0000 |
commit | 2e92b718d5ea063af4b2dc9412dcd2ce625b4987 (patch) | |
tree | 9d916997bcd61eaf055c622b2aa27919faeb9cc9 /rust | |
parent | Add type annotation to `visited_chains` (#17125) (diff) | |
download | synapse-2e92b718d5ea063af4b2dc9412dcd2ce625b4987.tar.xz |
MSC4108 implementation (#17056)
Co-authored-by: Hugh Nimmo-Smith <hughns@element.io> Co-authored-by: Hugh Nimmo-Smith <hughns@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Diffstat (limited to 'rust')
-rw-r--r-- | rust/Cargo.toml | 4 | ||||
-rw-r--r-- | rust/src/lib.rs | 2 | ||||
-rw-r--r-- | rust/src/rendezvous/mod.rs | 315 | ||||
-rw-r--r-- | rust/src/rendezvous/session.rs | 91 |
4 files changed, 412 insertions, 0 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9ac766182b..d41a216d1c 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,11 +23,13 @@ name = "synapse.synapse_rust" [dependencies] anyhow = "1.0.63" +base64 = "0.21.7" bytes = "1.6.0" headers = "0.4.0" http = "1.1.0" lazy_static = "1.4.0" log = "0.4.17" +mime = "0.3.17" pyo3 = { version = "0.20.0", features = [ "macros", "anyhow", @@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [ pyo3-log = "0.9.0" pythonize = "0.20.0" regex = "1.6.0" +sha2 = "0.10.8" serde = { version = "1.0.144", features = ["derive"] } serde_json = "1.0.85" +ulid = "1.1.2" [features] extension-module = ["pyo3/extension-module"] diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 36a3d64528..9bd1f17ad9 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -7,6 +7,7 @@ pub mod errors; pub mod events; pub mod http; pub mod push; +pub mod rendezvous; lazy_static! { static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init(); @@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { acl::register_module(py, m)?; push::register_module(py, m)?; events::register_module(py, m)?; + rendezvous::register_module(py, m)?; Ok(()) } diff --git a/rust/src/rendezvous/mod.rs b/rust/src/rendezvous/mod.rs new file mode 100644 index 0000000000..c0f5d8b600 --- /dev/null +++ b/rust/src/rendezvous/mod.rs @@ -0,0 +1,315 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + * + */ + +use std::{ + collections::{BTreeMap, HashMap}, + time::{Duration, SystemTime}, +}; + +use bytes::Bytes; +use headers::{ + AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType, + HeaderMapExt, IfMatch, IfNoneMatch, Pragma, +}; +use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri}; +use mime::Mime; +use pyo3::{ + exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult, + Python, ToPyObject, +}; +use ulid::Ulid; + +use self::session::Session; +use crate::{ + errors::{NotFoundError, SynapseError}, + http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt}, +}; + +mod session; + +// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers. +fn prepare_headers(headers: &mut HeaderMap, session: &Session) { + headers.typed_insert(AccessControlAllowOrigin::ANY); + headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG])); + headers.typed_insert(Pragma::no_cache()); + headers.typed_insert(CacheControl::new().with_no_store()); + headers.typed_insert(session.etag()); + headers.typed_insert(session.expires()); + headers.typed_insert(session.last_modified()); +} + +#[pyclass] +struct RendezvousHandler { + base: Uri, + clock: PyObject, + sessions: BTreeMap<Ulid, Session>, + capacity: usize, + max_content_length: u64, + ttl: Duration, +} + +impl RendezvousHandler { + /// Check the input headers of a request which sets data for a session, and return the content type. + fn check_input_headers(&self, headers: &HeaderMap) -> PyResult<Mime> { + let ContentLength(content_length) = headers.typed_get_required()?; + + if content_length > self.max_content_length { + return Err(SynapseError::new( + StatusCode::PAYLOAD_TOO_LARGE, + "Payload too large".to_owned(), + "M_TOO_LARGE", + None, + None, + )); + } + + let content_type: ContentType = headers.typed_get_required()?; + + // Content-Type must be text/plain + if content_type != ContentType::text() { + return Err(SynapseError::new( + StatusCode::BAD_REQUEST, + "Content-Type must be text/plain".to_owned(), + "M_INVALID_PARAM", + None, + None, + )); + } + + Ok(content_type.into()) + } + + /// Evict expired sessions and remove the oldest sessions until we're under the capacity. + fn evict(&mut self, now: SystemTime) { + // First remove all the entries which expired + self.sessions.retain(|_, session| !session.expired(now)); + + // Then we remove the oldest entires until we're under the limit + while self.sessions.len() > self.capacity { + self.sessions.pop_first(); + } + } +} + +#[pymethods] +impl RendezvousHandler { + #[new] + #[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))] + fn new( + py: Python<'_>, + homeserver: &PyAny, + capacity: usize, + max_content_length: u64, + eviction_interval: u64, + ttl: u64, + ) -> PyResult<Py<Self>> { + let base: String = homeserver + .getattr("config")? + .getattr("server")? + .getattr("public_baseurl")? + .extract()?; + let base = Uri::try_from(format!("{base}_synapse/client/rendezvous")) + .map_err(|_| PyValueError::new_err("Invalid base URI"))?; + + let clock = homeserver.call_method0("get_clock")?.to_object(py); + + // Construct a Python object so that we can get a reference to the + // evict method and schedule it to run. + let self_ = Py::new( + py, + Self { + base, + clock, + sessions: BTreeMap::new(), + capacity, + max_content_length, + ttl: Duration::from_millis(ttl), + }, + )?; + + let evict = self_.getattr(py, "_evict")?; + homeserver.call_method0("get_clock")?.call_method( + "looping_call", + (evict, eviction_interval), + None, + )?; + + Ok(self_) + } + + fn _evict(&mut self, py: Python<'_>) -> PyResult<()> { + let clock = self.clock.as_ref(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + self.evict(now); + + Ok(()) + } + + fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + let content_type = self.check_input_headers(request.headers())?; + + let clock = self.clock.as_ref(py); + let now: u64 = clock.call_method0("time_msec")?.extract()?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + // We trigger an immediate eviction if we're at 2x the capacity + if self.sessions.len() >= self.capacity * 2 { + self.evict(now); + } + + // Generate a new ULID for the session from the current time. + let id = Ulid::from_datetime(now); + + let uri = format!("{base}/{id}", base = self.base); + + let body = request.into_body(); + + let session = Session::new(body, content_type, now, self.ttl); + + let response = serde_json::json!({ + "url": uri, + }) + .to_string(); + + let mut response = Response::new(response.as_bytes()); + *response.status_mut() = StatusCode::CREATED; + response.headers_mut().typed_insert(ContentType::json()); + prepare_headers(response.headers_mut(), &session); + http_response_to_twisted(twisted_request, response)?; + + self.sessions.insert(id, session); + + Ok(()) + } + + fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + let if_none_match: Option<IfNoneMatch> = request.headers().typed_get_optional()?; + + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self + .sessions + .get(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; + + if let Some(if_none_match) = if_none_match { + if !if_none_match.precondition_passes(&session.etag()) { + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::NOT_MODIFIED; + prepare_headers(response.headers_mut(), session); + http_response_to_twisted(twisted_request, response)?; + return Ok(()); + } + } + + let mut response = Response::new(session.data()); + *response.status_mut() = StatusCode::OK; + let headers = response.headers_mut(); + prepare_headers(headers, session); + headers.typed_insert(session.content_type()); + headers.typed_insert(session.content_length()); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } + + fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> { + let request = http_request_from_twisted(twisted_request)?; + + let content_type = self.check_input_headers(request.headers())?; + + let if_match: IfMatch = request.headers().typed_get_required()?; + + let data = request.into_body(); + + let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?; + let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now); + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let session = self + .sessions + .get_mut(&id) + .filter(|s| !s.expired(now)) + .ok_or_else(NotFoundError::new)?; + + if !if_match.precondition_passes(&session.etag()) { + let mut headers = HeaderMap::new(); + prepare_headers(&mut headers, session); + + let mut additional_fields = HashMap::with_capacity(1); + additional_fields.insert( + String::from("org.matrix.msc4108.errcode"), + String::from("M_CONCURRENT_WRITE"), + ); + + return Err(SynapseError::new( + StatusCode::PRECONDITION_FAILED, + "ETag does not match".to_owned(), + "M_UNKNOWN", // Would be M_CONCURRENT_WRITE + Some(additional_fields), + Some(headers), + )); + } + + session.update(data, content_type, now); + + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::ACCEPTED; + prepare_headers(response.headers_mut(), session); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } + + fn handle_delete(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> { + let _request = http_request_from_twisted(twisted_request)?; + + let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?; + let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?; + + let mut response = Response::new(Bytes::new()); + *response.status_mut() = StatusCode::NO_CONTENT; + response + .headers_mut() + .typed_insert(AccessControlAllowOrigin::ANY); + http_response_to_twisted(twisted_request, response)?; + + Ok(()) + } +} + +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "rendezvous")?; + + child_module.add_class::<RendezvousHandler>()?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import rendezvous` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.rendezvous", child_module)?; + + Ok(()) +} diff --git a/rust/src/rendezvous/session.rs b/rust/src/rendezvous/session.rs new file mode 100644 index 0000000000..179304edfe --- /dev/null +++ b/rust/src/rendezvous/session.rs @@ -0,0 +1,91 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * <https://www.gnu.org/licenses/agpl-3.0.html>. + */ + +use std::time::{Duration, SystemTime}; + +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use bytes::Bytes; +use headers::{ContentLength, ContentType, ETag, Expires, LastModified}; +use mime::Mime; +use sha2::{Digest, Sha256}; + +/// A single session, containing data, metadata, and expiry information. +pub struct Session { + hash: [u8; 32], + data: Bytes, + content_type: Mime, + last_modified: SystemTime, + expires: SystemTime, +} + +impl Session { + /// Create a new session with the given data, content type, and time-to-live. + pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self { + let hash = Sha256::digest(&data).into(); + Self { + hash, + data, + content_type, + expires: now + ttl, + last_modified: now, + } + } + + /// Returns true if the session has expired at the given time. + pub fn expired(&self, now: SystemTime) -> bool { + self.expires <= now + } + + /// Update the session with new data, content type, and last modified time. + pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) { + self.hash = Sha256::digest(&data).into(); + self.data = data; + self.content_type = content_type; + self.last_modified = now; + } + + /// Returns the Content-Type header of the session. + pub fn content_type(&self) -> ContentType { + self.content_type.clone().into() + } + + /// Returns the Content-Length header of the session. + pub fn content_length(&self) -> ContentLength { + ContentLength(self.data.len() as _) + } + + /// Returns the ETag header of the session. + pub fn etag(&self) -> ETag { + let encoded = URL_SAFE_NO_PAD.encode(self.hash); + // SAFETY: Base64 encoding is URL-safe, so ETag-safe + format!("\"{encoded}\"") + .parse() + .expect("base64-encoded hash should be URL-safe") + } + + /// Returns the Last-Modified header of the session. + pub fn last_modified(&self) -> LastModified { + self.last_modified.into() + } + + /// Returns the Expires header of the session. + pub fn expires(&self) -> Expires { + self.expires.into() + } + + /// Returns the current data stored in the session. + pub fn data(&self) -> Bytes { + self.data.clone() + } +} |