summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock92
-rw-r--r--changelog.d/17081.misc1
-rw-r--r--rust/Cargo.toml3
-rw-r--r--rust/src/errors.rs60
-rw-r--r--rust/src/http.rs165
-rw-r--r--rust/src/lib.rs2
6 files changed, 321 insertions, 2 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 630d38c2f4..65f4807c65 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -30,6 +30,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
 
 [[package]]
+name = "base64"
+version = "0.21.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
+
+[[package]]
 name = "bitflags"
 version = "1.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -54,12 +60,27 @@ dependencies = [
 ]
 
 [[package]]
+name = "bytes"
+version = "1.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
+
+[[package]]
 name = "cfg-if"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
+name = "cpufeatures"
+version = "0.2.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
+dependencies = [
+ "libc",
+]
+
+[[package]]
 name = "crypto-common"
 version = "0.1.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -81,6 +102,12 @@ dependencies = [
 ]
 
 [[package]]
+name = "fnv"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
+
+[[package]]
 name = "generic-array"
 version = "0.14.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -91,6 +118,30 @@ dependencies = [
 ]
 
 [[package]]
+name = "headers"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
+dependencies = [
+ "base64",
+ "bytes",
+ "headers-core",
+ "http",
+ "httpdate",
+ "mime",
+ "sha1",
+]
+
+[[package]]
+name = "headers-core"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
+dependencies = [
+ "http",
+]
+
+[[package]]
 name = "heck"
 version = "0.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -103,6 +154,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
 
 [[package]]
+name = "http"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258"
+dependencies = [
+ "bytes",
+ "fnv",
+ "itoa",
+]
+
+[[package]]
+name = "httpdate"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
+
+[[package]]
 name = "indoc"
 version = "2.0.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -122,9 +190,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
 
 [[package]]
 name = "libc"
-version = "0.2.135"
+version = "0.2.153"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
+checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
 
 [[package]]
 name = "lock_api"
@@ -158,6 +226,12 @@ dependencies = [
 ]
 
 [[package]]
+name = "mime"
+version = "0.3.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
+
+[[package]]
 name = "once_cell"
 version = "1.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -377,6 +451,17 @@ dependencies = [
 ]
 
 [[package]]
+name = "sha1"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
+[[package]]
 name = "smallvec"
 version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -405,7 +490,10 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "blake2",
+ "bytes",
+ "headers",
  "hex",
+ "http",
  "lazy_static",
  "log",
  "pyo3",
diff --git a/changelog.d/17081.misc b/changelog.d/17081.misc
new file mode 100644
index 0000000000..d1ab69126c
--- /dev/null
+++ b/changelog.d/17081.misc
@@ -0,0 +1 @@
+Add helpers to transform Twisted requests to Rust http Requests/Responses.
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index ba293f8d4f..9ac766182b 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -23,6 +23,9 @@ name = "synapse.synapse_rust"
 
 [dependencies]
 anyhow = "1.0.63"
+bytes = "1.6.0"
+headers = "0.4.0"
+http = "1.1.0"
 lazy_static = "1.4.0"
 log = "0.4.17"
 pyo3 = { version = "0.20.0", features = [
diff --git a/rust/src/errors.rs b/rust/src/errors.rs
new file mode 100644
index 0000000000..4e580e3e8c
--- /dev/null
+++ b/rust/src/errors.rs
@@ -0,0 +1,60 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+#![allow(clippy::new_ret_no_self)]
+
+use std::collections::HashMap;
+
+use http::{HeaderMap, StatusCode};
+use pyo3::{exceptions::PyValueError, import_exception};
+
+import_exception!(synapse.api.errors, SynapseError);
+
+impl SynapseError {
+    pub fn new(
+        code: StatusCode,
+        message: String,
+        errcode: &'static str,
+        additional_fields: Option<HashMap<String, String>>,
+        headers: Option<HeaderMap>,
+    ) -> pyo3::PyErr {
+        // Transform the HeaderMap into a HashMap<String, String>
+        let headers = if let Some(headers) = headers {
+            let mut map = HashMap::with_capacity(headers.len());
+            for (key, value) in headers.iter() {
+                let Ok(value) = value.to_str() else {
+                    // This should never happen, but we don't want to panic in case it does
+                    return PyValueError::new_err(
+                        "Could not construct SynapseError: header value is not valid ASCII",
+                    );
+                };
+
+                map.insert(key.as_str().to_owned(), value.to_owned());
+            }
+            Some(map)
+        } else {
+            None
+        };
+
+        SynapseError::new_err((code.as_u16(), message, errcode, additional_fields, headers))
+    }
+}
+
+import_exception!(synapse.api.errors, NotFoundError);
+
+impl NotFoundError {
+    pub fn new() -> pyo3::PyErr {
+        NotFoundError::new_err(())
+    }
+}
diff --git a/rust/src/http.rs b/rust/src/http.rs
new file mode 100644
index 0000000000..74098f4c8b
--- /dev/null
+++ b/rust/src/http.rs
@@ -0,0 +1,165 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use headers::{Header, HeaderMapExt};
+use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode, Uri};
+use pyo3::{
+    exceptions::PyValueError,
+    types::{PyBytes, PySequence, PyTuple},
+    PyAny, PyResult,
+};
+
+use crate::errors::SynapseError;
+
+/// Read a file-like Python object by chunks
+///
+/// # Errors
+///
+/// Returns an error if calling the `read` on the Python object failed
+fn read_io_body(body: &PyAny, chunk_size: usize) -> PyResult<Bytes> {
+    let mut buf = BytesMut::new();
+    loop {
+        let bytes: &PyBytes = body.call_method1("read", (chunk_size,))?.downcast()?;
+        if bytes.as_bytes().is_empty() {
+            return Ok(buf.into());
+        }
+        buf.put(bytes.as_bytes());
+    }
+}
+
+/// Transform a Twisted `IRequest` to an [`http::Request`]
+///
+/// It uses the following members of `IRequest`:
+///   - `content`, which is expected to be a file-like object with a `read` method
+///   - `uri`, which is expected to be a valid URI as `bytes`
+///   - `method`, which is expected to be a valid HTTP method as `bytes`
+///   - `requestHeaders`, which is expected to have a `getAllRawHeaders` method
+///
+/// # Errors
+///
+/// Returns an error if the Python object doesn't properly implement `IRequest`
+pub fn http_request_from_twisted(request: &PyAny) -> PyResult<Request<Bytes>> {
+    let content = request.getattr("content")?;
+    let body = read_io_body(content, 4096)?;
+
+    let mut req = Request::new(body);
+
+    let uri: &PyBytes = request.getattr("uri")?.downcast()?;
+    *req.uri_mut() =
+        Uri::try_from(uri.as_bytes()).map_err(|_| PyValueError::new_err("invalid uri"))?;
+
+    let method: &PyBytes = request.getattr("method")?.downcast()?;
+    *req.method_mut() = Method::from_bytes(method.as_bytes())
+        .map_err(|_| PyValueError::new_err("invalid method"))?;
+
+    let headers_iter = request
+        .getattr("requestHeaders")?
+        .call_method0("getAllRawHeaders")?
+        .iter()?;
+
+    for header in headers_iter {
+        let header = header?;
+        let header: &PyTuple = header.downcast()?;
+        let name: &PyBytes = header.get_item(0)?.downcast()?;
+        let name = HeaderName::from_bytes(name.as_bytes())
+            .map_err(|_| PyValueError::new_err("invalid header name"))?;
+
+        let values: &PySequence = header.get_item(1)?.downcast()?;
+        for index in 0..values.len()? {
+            let value: &PyBytes = values.get_item(index)?.downcast()?;
+            let value = HeaderValue::from_bytes(value.as_bytes())
+                .map_err(|_| PyValueError::new_err("invalid header value"))?;
+            req.headers_mut().append(name.clone(), value);
+        }
+    }
+
+    Ok(req)
+}
+
+/// Send an [`http::Response`] through a Twisted `IRequest`
+///
+/// It uses the following members of `IRequest`:
+///
+///  - `responseHeaders`, which is expected to have a `addRawHeader(bytes, bytes)` method
+///  - `setResponseCode(int)` method
+///  - `write(bytes)` method
+///  - `finish()` method
+///
+///  # Errors
+///
+/// Returns an error if the Python object doesn't properly implement `IRequest`
+pub fn http_response_to_twisted<B>(request: &PyAny, response: Response<B>) -> PyResult<()>
+where
+    B: Buf,
+{
+    let (parts, mut body) = response.into_parts();
+
+    request.call_method1("setResponseCode", (parts.status.as_u16(),))?;
+
+    let response_headers = request.getattr("responseHeaders")?;
+    for (name, value) in parts.headers.iter() {
+        response_headers.call_method1("addRawHeader", (name.as_str(), value.as_bytes()))?;
+    }
+
+    while body.remaining() != 0 {
+        let chunk = body.chunk();
+        request.call_method1("write", (chunk,))?;
+        body.advance(chunk.len());
+    }
+
+    request.call_method0("finish")?;
+
+    Ok(())
+}
+
+/// An extension trait for [`HeaderMap`] that provides typed access to headers, and throws the
+/// right python exceptions when the header is missing or fails to parse.
+///
+/// [`HeaderMap`]: headers::HeaderMap
+pub trait HeaderMapPyExt: HeaderMapExt {
+    /// Get a header from the map, returning an error if it is missing or invalid.
+    fn typed_get_required<H>(&self) -> PyResult<H>
+    where
+        H: Header,
+    {
+        self.typed_get_optional::<H>()?.ok_or_else(|| {
+            SynapseError::new(
+                StatusCode::BAD_REQUEST,
+                format!("Missing required header: {}", H::name()),
+                "M_MISSING_PARAM",
+                None,
+                None,
+            )
+        })
+    }
+
+    /// Get a header from the map, returning `None` if it is missing and an error if it is invalid.
+    fn typed_get_optional<H>(&self) -> PyResult<Option<H>>
+    where
+        H: Header,
+    {
+        self.typed_try_get::<H>().map_err(|_| {
+            SynapseError::new(
+                StatusCode::BAD_REQUEST,
+                format!("Invalid header: {}", H::name()),
+                "M_INVALID_PARAM",
+                None,
+                None,
+            )
+        })
+    }
+}
+
+impl<T: HeaderMapExt> HeaderMapPyExt for T {}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 7b3b579e55..36a3d64528 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -3,7 +3,9 @@ use pyo3::prelude::*;
 use pyo3_log::ResetHandle;
 
 pub mod acl;
+pub mod errors;
 pub mod events;
+pub mod http;
 pub mod push;
 
 lazy_static! {