summary refs log tree commit diff
path: root/rust
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-12-14 11:02:16 +0000
committerErik Johnston <erik@matrix.org>2022-12-14 11:02:16 +0000
commitc93ef61fa38337691359b65a90f0a6d5bfc299a7 (patch)
tree5ce79ebd56eac6f417fe9816a050b1a1d84e7f9b /rust
parentAllow selecting "prejoin" events by state keys (#14642) (diff)
downloadsynapse-c93ef61fa38337691359b65a90f0a6d5bfc299a7.tar.xz
WIP Rust HTTP for federation
Diffstat (limited to 'rust')
-rw-r--r--rust/Cargo.toml10
-rw-r--r--rust/src/http/mod.rs149
-rw-r--r--rust/src/http/resolver.rs428
-rw-r--r--rust/src/lib.rs2
4 files changed, 589 insertions, 0 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index cffaa5b51b..f96f8c4041 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -21,14 +21,24 @@ name = "synapse.synapse_rust"
 
 [dependencies]
 anyhow = "1.0.63"
+futures = "0.3.25"
+futures-util = "0.3.25"
+http = "0.2.8"
+hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
+hyper-tls = "0.5.0"
 lazy_static = "1.4.0"
 log = "0.4.17"
+native-tls = "0.2.11"
 pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
+pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
 pyo3-log = "0.7.0"
 pythonize = "0.17.0"
 regex = "1.6.0"
 serde = { version = "1.0.144", features = ["derive"] }
 serde_json = "1.0.85"
+tokio = "1.23.0"
+tokio-native-tls = "0.3.0"
+trust-dns-resolver = "0.22.0"
 
 [build-dependencies]
 blake2 = "0.10.4"
diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs
new file mode 100644
index 0000000000..b533a3d36d
--- /dev/null
+++ b/rust/src/http/mod.rs
@@ -0,0 +1,149 @@
+use std::collections::HashMap;
+
+use anyhow::Error;
+use http::Request;
+use hyper::Body;
+use log::info;
+use pyo3::{
+    pyclass, pymethods,
+    types::{PyBytes, PyModule},
+    IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
+};
+
+use self::resolver::{MatrixConnector, MatrixResolver};
+
+mod resolver;
+
+/// Registers the `http` submodule on the parent Python module.
+///
+/// Adds the [`HttpClient`] and [`MatrixResponse`] classes to a new `http`
+/// module, attaches it to `m`, and publishes it in `sys.modules`.
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    let http_module = PyModule::new(py, "http")?;
+    http_module.add_class::<HttpClient>()?;
+    http_module.add_class::<MatrixResponse>()?;
+
+    m.add_submodule(http_module)?;
+
+    // We need to manually add the module to `sys.modules` to make imports of
+    // `synapse.synapse_rust.http` work.
+    let sys_modules = py.import("sys")?.getattr("modules")?;
+    sys_modules.set_item("synapse.synapse_rust.http", http_module)?;
+
+    Ok(())
+}
+
+/// Owned byte buffer that is converted to a Python `bytes` object.
+#[derive(Clone)]
+struct Bytes(Vec<u8>);
+
+impl ToPyObject for Bytes {
+    /// Builds a new Python `bytes` object from the wrapped buffer.
+    fn to_object(&self, py: Python<'_>) -> PyObject {
+        let py_bytes = PyBytes::new(py, &self.0);
+        py_bytes.into_py(py)
+    }
+}
+
+impl IntoPy<PyObject> for Bytes {
+    /// Delegates to the [`ToPyObject`] implementation.
+    fn into_py(self, py: Python<'_>) -> PyObject {
+        ToPyObject::to_object(&self, py)
+    }
+}
+
+// Buffered HTTP response returned by `HttpClient::async_request`. Every field
+// is exposed to Python as a read-only attribute through `#[pyo3(get)]`.
+#[pyclass]
+pub struct MatrixResponse {
+    // Numeric HTTP status code.
+    #[pyo3(get)]
+    code: u16,
+    // Canonical reason phrase for `code`; empty when the status has none.
+    #[pyo3(get)]
+    phrase: &'static str,
+    // Complete response body.
+    #[pyo3(get)]
+    content: Bytes,
+    // Response headers: header name mapped to the raw bytes of its value.
+    #[pyo3(get)]
+    headers: HashMap<String, Bytes>,
+}
+
+// Python-visible HTTP client. Wraps a hyper client whose connections are
+// established by `MatrixConnector`. It is `Clone` so that `request` can move
+// a copy of the client into the future it hands back to Python.
+#[pyclass]
+#[derive(Clone)]
+pub struct HttpClient {
+    client: hyper::Client<MatrixConnector>,
+}
+
+impl HttpClient {
+    /// Builds a client whose connections are made by a [`MatrixConnector`]
+    /// backed by a newly created [`MatrixResolver`].
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from [`MatrixResolver::new`].
+    pub fn new() -> Result<Self, Error> {
+        let connector = MatrixConnector::with_resolver(MatrixResolver::new()?);
+        let client = hyper::Client::builder().build(connector);
+
+        Ok(HttpClient { client })
+    }
+
+    /// Sends one request and buffers the whole response.
+    ///
+    /// * `url` - request URI.
+    /// * `method` - HTTP method name.
+    /// * `headers` - header name mapped to the list of values to send for it;
+    ///   every value is added as its own header.
+    /// * `body` - request body; `None` sends an empty body.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the request cannot be built, if sending it fails, or if the
+    /// response body cannot be read.
+    pub async fn async_request(
+        &self,
+        url: String,
+        method: String,
+        headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
+        body: Option<Vec<u8>>,
+    ) -> Result<MatrixResponse, Error> {
+        let mut builder = Request::builder().method(method.as_str()).uri(url);
+
+        for (name, values) in headers {
+            for value in values {
+                builder = builder.header(name.as_slice(), value);
+            }
+        }
+
+        let request = builder.body(body.map_or_else(Body::empty, Body::from))?;
+
+        let response = self.client.request(request).await?;
+
+        let status = response.status();
+        let code = status.as_u16();
+        let phrase = status.canonical_reason().unwrap_or_default();
+
+        // The headers have to be copied out before the response is consumed
+        // to read its body.
+        let headers = response
+            .headers()
+            .iter()
+            .map(|(name, value)| (name.to_string(), Bytes(value.as_bytes().to_owned())))
+            .collect();
+
+        let raw_body = hyper::body::to_bytes(response.into_body()).await?;
+        let content = Bytes(raw_body.to_vec());
+
+        info!("DONE");
+
+        Ok(MatrixResponse {
+            code,
+            phrase,
+            content,
+            headers,
+        })
+    }
+}
+
+// Methods exposed to Python.
+#[pymethods]
+impl HttpClient {
+    // Python constructor; delegates to `HttpClient::new`.
+    #[new]
+    fn py_new() -> Result<Self, Error> {
+        Self::new()
+    }
+
+    // Starts a request and returns the Python object that
+    // `pyo3_asyncio::tokio::future_into_py` creates for the future. The
+    // arguments are forwarded unchanged to `async_request`, and the future
+    // resolves to a `MatrixResponse`.
+    fn request<'a>(
+        &'a self,
+        py: Python<'a>,
+        url: String,
+        method: String,
+        headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
+        body: Option<Vec<u8>>,
+    ) -> PyResult<&'a PyAny> {
+        // NOTE(review): this method already receives a `Python` token, so the
+        // interpreter is presumably initialised by the time we get here —
+        // confirm whether this call is still needed.
+        pyo3::prepare_freethreaded_python();
+        info!("REQUEST");
+
+        // The future must own everything it uses, so it gets its own clone of
+        // the client instead of borrowing `self`.
+        let client = self.clone();
+
+        pyo3_asyncio::tokio::future_into_py(py, async move {
+            // `?` converts the `anyhow::Error` into the error type expected
+            // by `future_into_py`.
+            let resp = client.async_request(url, method, headers, body).await?;
+            Ok(resp)
+        })
+    }
+}
diff --git a/rust/src/http/resolver.rs b/rust/src/http/resolver.rs
new file mode 100644
index 0000000000..77e9bf5c20
--- /dev/null
+++ b/rust/src/http/resolver.rs
@@ -0,0 +1,428 @@
+use std::collections::BTreeMap;
+use std::future::Future;
+use std::net::IpAddr;
+use std::pin::Pin;
+use std::str::FromStr;
+use std::{
+    io::Cursor,
+    sync::{Arc, Mutex},
+    task::{self, Poll},
+};
+
+use anyhow::{bail, Error};
+use futures::{FutureExt, TryFutureExt};
+use futures_util::stream::StreamExt;
+use http::Uri;
+use hyper::client::connect::Connection;
+use hyper::client::connect::{Connected, HttpConnector};
+use hyper::server::conn::Http;
+use hyper::service::Service;
+use hyper::Client;
+use hyper_tls::HttpsConnector;
+use hyper_tls::MaybeHttpsStream;
+use log::info;
+use native_tls::TlsConnector;
+use serde::Deserialize;
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::net::TcpStream;
+use tokio_native_tls::TlsConnector as AsyncTlsConnector;
+use trust_dns_resolver::error::ResolveErrorKind;
+
+/// A concrete target to try when connecting to a Matrix server, as produced
+/// by [`MatrixResolver`].
+pub struct Endpoint {
+    /// Host name or IP address the TCP connection is opened to.
+    pub host: String,
+    /// TCP port the connection is opened to.
+    pub port: u16,
+
+    /// Presumably the value to send as the HTTP `Host` header; it is filled
+    /// in by the resolver but not read anywhere in this file — confirm the
+    /// intended use.
+    pub host_header: String,
+    /// Name handed to the TLS connector when the TLS session is established.
+    pub tls_name: String,
+}
+
+/// Resolves Matrix server names into the [`Endpoint`]s to connect to, using a
+/// well-known lookup over HTTPS and a DNS SRV lookup.
+#[derive(Clone)]
+pub struct MatrixResolver {
+    /// DNS resolver used for the `_matrix._tcp.<host>` SRV lookup.
+    resolver: trust_dns_resolver::TokioAsyncResolver,
+    /// HTTP client used to fetch `/.well-known/matrix/server`.
+    http_client: Client<HttpsConnector<HttpConnector>>,
+}
+
+impl MatrixResolver {
+    /// Creates a resolver that performs well-known lookups with a new
+    /// HTTPS-capable hyper client.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error from [`MatrixResolver::with_client`].
+    pub fn new() -> Result<MatrixResolver, Error> {
+        let http_client = hyper::Client::builder().build(HttpsConnector::new());
+
+        MatrixResolver::with_client(http_client)
+    }
+
+    /// Creates a resolver that performs well-known lookups with `http_client`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the DNS resolver cannot be set up from the system
+    /// configuration.
+    pub fn with_client(
+        http_client: Client<HttpsConnector<HttpConnector>>,
+    ) -> Result<MatrixResolver, Error> {
+        let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;
+
+        Ok(MatrixResolver {
+            resolver,
+            http_client,
+        })
+    }
+
+    /// Resolves the server named by the host and optional port of `uri`.
+    ///
+    /// This is a convenience wrapper around
+    /// [`MatrixResolver::resolve_server_name_from_host_port`].
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `uri` has no host, or if the resolution itself
+    /// fails.
+    pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
+        // A URI without a host is bad input, not a broken invariant, so it is
+        // reported as an error instead of panicking.
+        let host = match uri.host() {
+            Some(host) => host.to_string(),
+            None => bail!("URI has no host"),
+        };
+        let port = uri.port_u16();
+
+        self.resolve_server_name_from_host_port(host, port).await
+    }
+
+    /// Resolves a server name into the list of endpoints to try, in order.
+    ///
+    /// The steps are:
+    ///
+    /// 1. If `host` is an IP literal or `port` is given, that host and port
+    ///    (default 8448) is the only endpoint.
+    /// 2. Otherwise a well-known lookup is made; if it succeeds, its
+    ///    `m.server` value replaces the host, port and authority.
+    /// 3. If the (possibly delegated) host is an IP literal or has a port,
+    ///    that host and port (default 8448) is the only endpoint.
+    /// 4. Otherwise `_matrix._tcp.<host>` is looked up as an SRV record. With
+    ///    no records the endpoint is `<host>:8448`; otherwise one endpoint is
+    ///    returned per record, ordered by ascending priority.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the well-known `m.server` value is not a valid authority, or
+    /// if the SRV lookup fails for a reason other than "no records found".
+    pub async fn resolve_server_name_from_host_port(
+        &self,
+        mut host: String,
+        mut port: Option<u16>,
+    ) -> Result<Vec<Endpoint>, Error> {
+        let mut authority = if let Some(p) = port {
+            format!("{}:{}", host, p)
+        } else {
+            host.clone()
+        };
+
+        // If a literal IP or includes port then we shortcircuit.
+        if host.parse::<IpAddr>().is_ok() || port.is_some() {
+            return Ok(vec![Endpoint {
+                host: host.clone(),
+                port: port.unwrap_or(8448),
+
+                host_header: authority.clone(),
+                tls_name: host.clone(),
+            }]);
+        }
+
+        // Do well-known delegation lookup.
+        if let Some(server) = get_well_known(&self.http_client, &host).await {
+            let a = http::uri::Authority::from_str(&server.server)?;
+            host = a.host().to_string();
+            port = a.port_u16();
+            authority = a.to_string();
+        }
+
+        // If a literal IP or includes port then we shortcircuit.
+        if host.parse::<IpAddr>().is_ok() || port.is_some() {
+            return Ok(vec![Endpoint {
+                host: host.clone(),
+                port: port.unwrap_or(8448),
+
+                host_header: authority.clone(),
+                tls_name: host.clone(),
+            }]);
+        }
+
+        let result = self
+            .resolver
+            .srv_lookup(format!("_matrix._tcp.{}", host))
+            .await;
+
+        let records = match result {
+            Ok(records) => records,
+            Err(err) => match err.kind() {
+                // No SRV records: fall back to the host itself on the default
+                // federation port.
+                ResolveErrorKind::NoRecordsFound { .. } => {
+                    return Ok(vec![Endpoint {
+                        host: host.clone(),
+                        port: 8448,
+                        host_header: authority.clone(),
+                        tls_name: host.clone(),
+                    }])
+                }
+                _ => return Err(err.into()),
+            },
+        };
+
+        // Group the records by priority; a `BTreeMap` iterates its keys in
+        // ascending order, so lower priority values come first.
+        let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();
+
+        let mut count = 0;
+        for record in records {
+            count += 1;
+            let priority = record.priority();
+            priority_map.entry(priority).or_default().push(record);
+        }
+
+        let mut results = Vec::with_capacity(count);
+
+        for (_priority, records) in priority_map {
+            // TODO: Correctly shuffle records
+            results.extend(records.into_iter().map(|record| Endpoint {
+                host: record.target().to_utf8(),
+                port: record.port(),
+
+                host_header: host.clone(),
+                tls_name: host.clone(),
+            }))
+        }
+
+        Ok(results)
+    }
+}
+
+/// Fetches and parses `https://<host>/.well-known/matrix/server`.
+///
+/// Returns `None` whenever the lookup does not yield a usable delegation: the
+/// URI cannot be built, the request fails, the response status is not
+/// `200 OK`, the body is larger than the size cap, or the body is not valid
+/// JSON containing an `m.server` string.
+async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
+where
+    C: Service<Uri> + Clone + Sync + Send + 'static,
+    C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+    C::Future: Unpin + Send,
+    C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
+{
+    /// Upper bound on the number of body bytes that will be buffered. The
+    /// body comes from a remote server, so it must not be read unboundedly.
+    const MAX_BODY_SIZE: usize = 50 * 1024;
+
+    // TODO: Add timeout.
+
+    let uri = hyper::Uri::builder()
+        .scheme("https")
+        .authority(host)
+        .path_and_query("/.well-known/matrix/server")
+        .build()
+        .ok()?;
+
+    let response = http_client.get(uri).await.ok()?;
+
+    // Anything other than a 200 response is treated as an invalid well-known
+    // reply, even if the body happens to parse.
+    if response.status() != hyper::StatusCode::OK {
+        return None;
+    }
+
+    let mut body = response.into_body();
+
+    let mut vec = Vec::new();
+    while let Some(next) = body.next().await {
+        let chunk = next.ok()?;
+        if vec.len() + chunk.len() > MAX_BODY_SIZE {
+            return None;
+        }
+        vec.extend_from_slice(&chunk);
+    }
+
+    // The target type is inferred as `Option<WellKnownServer>`, so a JSON
+    // `null` body also results in `None`.
+    serde_json::from_slice(&vec).ok()?
+}
+
+/// Body of a `/.well-known/matrix/server` response.
+#[derive(Deserialize)]
+struct WellKnownServer {
+    /// Delegated server name from the `m.server` key; it is parsed as a URI
+    /// authority (host with optional port) by the resolver.
+    #[serde(rename = "m.server")]
+    server: String,
+}
+
+/// Connector for hyper clients that resolves `matrix` scheme URIs through a
+/// [`MatrixResolver`] before connecting.
+#[derive(Clone)]
+pub struct MatrixConnector {
+    resolver: MatrixResolver,
+}
+
+impl MatrixConnector {
+    /// Wraps `resolver` in a new connector.
+    pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
+        Self { resolver }
+    }
+}
+
+impl Service<Uri> for MatrixConnector {
+    type Response = MaybeHttpsStream<TcpStream>;
+    type Error = Error;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    /// Always reports ready; no state needs to be prepared before `call`.
+    fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    /// Opens a connection for `dst`.
+    ///
+    /// URIs whose scheme is not `matrix` are handed to a plain
+    /// `HttpsConnector`. For `matrix` URIs the server name is resolved into a
+    /// list of endpoints, which are tried in order until one connects.
+    fn call(&mut self, dst: Uri) -> Self::Future {
+        if dst.scheme_str() != Some("matrix") {
+            return HttpsConnector::new()
+                .call(dst)
+                .map_err(Error::msg)
+                .boxed();
+        }
+
+        let resolver = self.resolver.clone();
+
+        async move {
+            let server_name = dst.host().expect("hostname").to_string();
+            let endpoints = resolver
+                .resolve_server_name_from_host_port(server_name, dst.port_u16())
+                .await?;
+
+            for endpoint in endpoints {
+                let err = match try_connecting(&dst, &endpoint).await {
+                    Ok(stream) => return Ok(stream),
+                    Err(err) => err,
+                };
+
+                // A failing endpoint is not unexpected; log it and try the
+                // next one.
+                info!(
+                    "Failed to connect to {} via {}:{} because {}",
+                    dst.host().expect("hostname"),
+                    endpoint.host,
+                    endpoint.port,
+                    err,
+                );
+            }
+
+            bail!(
+                "failed to resolve host: {:?} port {:?}",
+                dst.host(),
+                dst.port()
+            )
+        }
+        .boxed()
+    }
+}
+
+/// Attempts to connect to a particular endpoint.
+///
+/// Opens a TCP connection to `endpoint.host:endpoint.port` and then performs
+/// a TLS handshake using `endpoint.tls_name` as the expected server name.
+///
+/// Certificate validation is only skipped when the host of `dst` is exactly
+/// `localhost` or ends in `.localhost`. A substring match would also disable
+/// validation for remote names such as `localhost.example.com`.
+///
+/// # Errors
+///
+/// Fails if the TCP connection, the TLS connector setup, or the TLS handshake
+/// fails.
+async fn try_connecting(
+    dst: &Uri,
+    endpoint: &Endpoint,
+) -> Result<MaybeHttpsStream<TcpStream>, Error> {
+    let tcp = TcpStream::connect((endpoint.host.as_str(), endpoint.port)).await?;
+
+    // A URI without a host is simply not treated as local, so it gets full
+    // certificate validation.
+    let is_localhost = dst
+        .host()
+        .map_or(false, |host| host == "localhost" || host.ends_with(".localhost"));
+
+    let connector: AsyncTlsConnector = if is_localhost {
+        TlsConnector::builder()
+            .danger_accept_invalid_certs(true)
+            .build()?
+            .into()
+    } else {
+        // Propagate a TLS backend initialisation failure instead of panicking.
+        TlsConnector::new()?.into()
+    };
+
+    let tls = connector.connect(&endpoint.tls_name, tcp).await?;
+
+    Ok(tls.into())
+}
+
+/// A connector that returns an in-memory connection whose server side answers
+/// every request with a "Hello World" response.
+#[derive(Clone)]
+pub struct TestConnector;
+
+impl Service<Uri> for TestConnector {
+    type Response = TestConnection;
+    type Error = Error;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    /// Always reports ready.
+    fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        // This connector is always ready, but others might not be.
+        Poll::Ready(Ok(()))
+    }
+
+    /// Creates a connected pair of in-memory connections, serves HTTP on one
+    /// end in a spawned task, and returns the other end. The destination URI
+    /// is ignored.
+    fn call(&mut self, _dst: Uri) -> Self::Future {
+        let (client, server) = TestConnection::double_ended();
+
+        {
+            // Every request gets the same fixed response.
+            let service = hyper::service::service_fn(|_| async move {
+                Ok(hyper::Response::new(hyper::Body::from("Hello World")))
+                    as Result<_, hyper::http::Error>
+            });
+            let fut = Http::new().serve_connection(server, service);
+            // Run the server half in the background; the join handle is
+            // dropped, so the task is detached.
+            tokio::spawn(fut);
+        }
+
+        futures::future::ok(client).boxed()
+    }
+}
+
+/// State shared by the two ends of a [`TestConnection`] pair.
+#[derive(Default)]
+struct TestConnectionInner {
+    /// Written by the `direction == true` end, read by the other end.
+    outbound_buffer: Cursor<Vec<u8>>,
+    /// Written by the `direction == false` end, read by the other end.
+    inbound_buffer: Cursor<Vec<u8>>,
+    /// Wakers of reads that found no data; all are woken on the next write.
+    wakers: Vec<futures::task::Waker>,
+}
+
+/// An in-memory connection for use with tests.
+#[derive(Clone, Default)]
+pub struct TestConnection {
+    /// Buffers and wakers shared with the other end of the pair.
+    inner: Arc<Mutex<TestConnectionInner>>,
+    /// Selects which buffer this end reads from and which it writes to.
+    direction: bool,
+}
+
+impl TestConnection {
+    /// Creates the two ends of an in-memory connection.
+    ///
+    /// Both ends share the same state; the first has `direction == false`
+    /// and the second has `direction == true`.
+    pub fn double_ended() -> (TestConnection, TestConnection) {
+        let shared: Arc<Mutex<TestConnectionInner>> =
+            Arc::new(Mutex::new(TestConnectionInner::default()));
+
+        (
+            TestConnection {
+                inner: Arc::clone(&shared),
+                direction: false,
+            },
+            TestConnection {
+                inner: shared,
+                direction: true,
+            },
+        )
+    }
+}
+
+impl AsyncRead for TestConnection {
+    /// Reads whatever is available in the buffer this end receives from.
+    ///
+    /// The `direction == true` end reads `inbound_buffer`; the other end
+    /// reads `outbound_buffer`. If no bytes could be read, the task's waker
+    /// is stored and `Poll::Pending` is returned.
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut task::Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        let mut conn = self.inner.lock().expect("mutex");
+
+        let buffer = if self.direction {
+            &mut conn.inbound_buffer
+        } else {
+            &mut conn.outbound_buffer
+        };
+
+        // The cursor keeps its read position, so each call continues where
+        // the previous read stopped.
+        let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?;
+        buf.advance(bytes_read);
+        if bytes_read > 0 {
+            Poll::Ready(Ok(()))
+        } else {
+            // Nothing to read yet: park until a write wakes the stored wakers.
+            conn.wakers.push(cx.waker().clone());
+            Poll::Pending
+        }
+    }
+}
+
+impl AsyncWrite for TestConnection {
+    /// Appends `buf` to the buffer this end sends on and wakes all stored
+    /// wakers.
+    ///
+    /// The `direction == true` end writes to `outbound_buffer`; the other end
+    /// writes to `inbound_buffer`. The whole of `buf` is always accepted.
+    fn poll_write(
+        self: Pin<&mut Self>,
+        _cx: &mut task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        let mut conn = self.inner.lock().expect("mutex");
+
+        // Append to the underlying `Vec` directly, leaving the cursor's
+        // position (used by reads) untouched.
+        if self.direction {
+            conn.outbound_buffer.get_mut().extend_from_slice(buf);
+        } else {
+            conn.inbound_buffer.get_mut().extend_from_slice(buf);
+        }
+
+        // Wake every read that was parked waiting for data.
+        for waker in conn.wakers.drain(..) {
+            waker.wake()
+        }
+
+        Poll::Ready(Ok(buf.len()))
+    }
+    /// Delegates the flush to the cursor this end writes to.
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        cx: &mut task::Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        let mut conn = self.inner.lock().expect("mutex");
+
+        if self.direction {
+            Pin::new(&mut conn.outbound_buffer).poll_flush(cx)
+        } else {
+            Pin::new(&mut conn.inbound_buffer).poll_flush(cx)
+        }
+    }
+    /// Delegates the shutdown to the cursor this end writes to.
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut task::Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        let mut conn = self.inner.lock().expect("mutex");
+
+        if self.direction {
+            Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx)
+        } else {
+            Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx)
+        }
+    }
+}
+
+impl Connection for TestConnection {
+    /// Returns default connection metadata; required so that the type can be
+    /// used as the response of a hyper connector.
+    fn connected(&self) -> Connected {
+        Connected::new()
+    }
+}
+
+/// Checks that a request made through [`TestConnector`] succeeds and returns
+/// the fixed "Hello World" body.
+#[tokio::test]
+async fn test_memory_connection() {
+    let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);
+
+    let uri: Uri = "http://localhost".parse().unwrap();
+    let response = client.get(uri).await.unwrap();
+
+    assert!(response.status().is_success());
+
+    let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
+    assert_eq!(&body[..], b"Hello World");
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index c7b60e58a7..1a19705b4f 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -1,5 +1,6 @@
 use pyo3::prelude::*;
 
+pub mod http;
 pub mod push;
 
 /// Returns the hash of all the rust source files at the time it was compiled.
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
 
     push::register_module(py, m)?;
+    http::register_module(py, m)?;
 
     Ok(())
 }