diff options
author | Erik Johnston <erik@matrix.org> | 2022-12-14 11:02:16 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-12-14 11:02:16 +0000 |
commit | c93ef61fa38337691359b65a90f0a6d5bfc299a7 (patch) | |
tree | 5ce79ebd56eac6f417fe9816a050b1a1d84e7f9b /rust | |
parent | Allow selecting "prejoin" events by state keys (#14642) (diff) | |
download | synapse-c93ef61fa38337691359b65a90f0a6d5bfc299a7.tar.xz |
WIP Rust HTTP for federation
Diffstat (limited to 'rust')
-rw-r--r-- | rust/Cargo.toml | 10 | ||||
-rw-r--r-- | rust/src/http/mod.rs | 149 | ||||
-rw-r--r-- | rust/src/http/resolver.rs | 428 | ||||
-rw-r--r-- | rust/src/lib.rs | 2 |
4 files changed, 589 insertions, 0 deletions
diff --git a/rust/Cargo.toml b/rust/Cargo.toml index cffaa5b51b..f96f8c4041 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -21,14 +21,24 @@ name = "synapse.synapse_rust" [dependencies] anyhow = "1.0.63" +futures = "0.3.25" +futures-util = "0.3.25" +http = "0.2.8" +hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] } +hyper-tls = "0.5.0" lazy_static = "1.4.0" log = "0.4.17" +native-tls = "0.2.11" pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] } +pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] } pyo3-log = "0.7.0" pythonize = "0.17.0" regex = "1.6.0" serde = { version = "1.0.144", features = ["derive"] } serde_json = "1.0.85" +tokio = "1.23.0" +tokio-native-tls = "0.3.0" +trust-dns-resolver = "0.22.0" [build-dependencies] blake2 = "0.10.4" diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs new file mode 100644 index 0000000000..b533a3d36d --- /dev/null +++ b/rust/src/http/mod.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; + +use anyhow::Error; +use http::Request; +use hyper::Body; +use log::info; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyModule}, + IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject, +}; + +use self::resolver::{MatrixConnector, MatrixResolver}; + +mod resolver; + +/// Called when registering modules with python. +pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "http")?; + child_module.add_class::<HttpClient>()?; + child_module.add_class::<MatrixResponse>()?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import http` work. + py.import("sys")? + .getattr("modules")? 
+ .set_item("synapse.synapse_rust.http", child_module)?; + + Ok(()) +} + +#[derive(Clone)] +struct Bytes(Vec<u8>); + +impl ToPyObject for Bytes { + fn to_object(&self, py: Python<'_>) -> pyo3::PyObject { + PyBytes::new(py, &self.0).into_py(py) + } +} + +impl IntoPy<PyObject> for Bytes { + fn into_py(self, py: Python<'_>) -> PyObject { + self.to_object(py) + } +} + +#[pyclass] +pub struct MatrixResponse { + #[pyo3(get)] + code: u16, + #[pyo3(get)] + phrase: &'static str, + #[pyo3(get)] + content: Bytes, + #[pyo3(get)] + headers: HashMap<String, Bytes>, +} + +#[pyclass] +#[derive(Clone)] +pub struct HttpClient { + client: hyper::Client<MatrixConnector>, +} + +impl HttpClient { + pub fn new() -> Result<Self, Error> { + let resolver = MatrixResolver::new()?; + + let client = hyper::Client::builder().build(MatrixConnector::with_resolver(resolver)); + + Ok(HttpClient { client }) + } + + pub async fn async_request( + &self, + url: String, + method: String, + headers: HashMap<Vec<u8>, Vec<Vec<u8>>>, + body: Option<Vec<u8>>, + ) -> Result<MatrixResponse, Error> { + let mut builder = Request::builder().method(&*method).uri(url); + + for (key, values) in headers { + for value in values { + builder = builder.header(key.clone(), value); + } + } + + let request = if let Some(body) = body { + builder.body(Body::from(body))? + } else { + builder.body(Body::empty())? 
+ }; + + let response = self.client.request(request).await?; + + let code = response.status().as_u16(); + let phrase = response.status().canonical_reason().unwrap_or_default(); + + let headers = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned()))) + .collect(); + + let body = response.into_body(); + + let bytes = hyper::body::to_bytes(body).await?; + let content = Bytes(bytes.to_vec()); + + info!("DONE"); + + Ok(MatrixResponse { + code, + phrase, + content, + headers, + }) + } +} + +#[pymethods] +impl HttpClient { + #[new] + fn py_new() -> Result<Self, Error> { + Self::new() + } + + fn request<'a>( + &'a self, + py: Python<'a>, + url: String, + method: String, + headers: HashMap<Vec<u8>, Vec<Vec<u8>>>, + body: Option<Vec<u8>>, + ) -> PyResult<&'a PyAny> { + pyo3::prepare_freethreaded_python(); + info!("REQUEST"); + + let client = self.clone(); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let resp = client.async_request(url, method, headers, body).await?; + Ok(resp) + }) + } +} diff --git a/rust/src/http/resolver.rs b/rust/src/http/resolver.rs new file mode 100644 index 0000000000..77e9bf5c20 --- /dev/null +++ b/rust/src/http/resolver.rs @@ -0,0 +1,428 @@ +use std::collections::BTreeMap; +use std::future::Future; +use std::net::IpAddr; +use std::pin::Pin; +use std::str::FromStr; +use std::{ + io::Cursor, + sync::{Arc, Mutex}, + task::{self, Poll}, +}; + +use anyhow::{bail, Error}; +use futures::{FutureExt, TryFutureExt}; +use futures_util::stream::StreamExt; +use http::Uri; +use hyper::client::connect::Connection; +use hyper::client::connect::{Connected, HttpConnector}; +use hyper::server::conn::Http; +use hyper::service::Service; +use hyper::Client; +use hyper_tls::HttpsConnector; +use hyper_tls::MaybeHttpsStream; +use log::info; +use native_tls::TlsConnector; +use serde::Deserialize; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpStream; +use tokio_native_tls::TlsConnector as 
AsyncTlsConnector; +use trust_dns_resolver::error::ResolveErrorKind; + +pub struct Endpoint { + pub host: String, + pub port: u16, + + pub host_header: String, + pub tls_name: String, +} + +#[derive(Clone)] +pub struct MatrixResolver { + resolver: trust_dns_resolver::TokioAsyncResolver, + http_client: Client<HttpsConnector<HttpConnector>>, +} + +impl MatrixResolver { + pub fn new() -> Result<MatrixResolver, Error> { + let http_client = hyper::Client::builder().build(HttpsConnector::new()); + + MatrixResolver::with_client(http_client) + } + + pub fn with_client( + http_client: Client<HttpsConnector<HttpConnector>>, + ) -> Result<MatrixResolver, Error> { + let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?; + + Ok(MatrixResolver { + resolver, + http_client, + }) + } + + /// Does SRV lookup + pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> { + let host = uri.host().expect("URI has no host").to_string(); + let port = uri.port_u16(); + + self.resolve_server_name_from_host_port(host, port).await + } + + pub async fn resolve_server_name_from_host_port( + &self, + mut host: String, + mut port: Option<u16>, + ) -> Result<Vec<Endpoint>, Error> { + let mut authority = if let Some(p) = port { + format!("{}:{}", host, p) + } else { + host.to_string() + }; + + // If a literal IP or includes port then we shortcircuit. + if host.parse::<IpAddr>().is_ok() || port.is_some() { + return Ok(vec![Endpoint { + host: host.to_string(), + port: port.unwrap_or(8448), + + host_header: authority.to_string(), + tls_name: host.to_string(), + }]); + } + + // Do well-known delegation lookup. + if let Some(server) = get_well_known(&self.http_client, &host).await { + let a = http::uri::Authority::from_str(&server.server)?; + host = a.host().to_string(); + port = a.port_u16(); + authority = a.to_string(); + } + + // If a literal IP or includes port then we shortcircuit. 
+ if host.parse::<IpAddr>().is_ok() || port.is_some() { + return Ok(vec![Endpoint { + host: host.clone(), + port: port.unwrap_or(8448), + + host_header: authority.to_string(), + tls_name: host.clone(), + }]); + } + + let result = self + .resolver + .srv_lookup(format!("_matrix._tcp.{}", host)) + .await; + + let records = match result { + Ok(records) => records, + Err(err) => match err.kind() { + ResolveErrorKind::NoRecordsFound { .. } => { + return Ok(vec![Endpoint { + host: host.clone(), + port: 8448, + host_header: authority.to_string(), + tls_name: host.clone(), + }]) + } + _ => return Err(err.into()), + }, + }; + + let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new(); + + let mut count = 0; + for record in records { + count += 1; + let priority = record.priority(); + priority_map.entry(priority).or_default().push(record); + } + + let mut results = Vec::with_capacity(count); + + for (_priority, records) in priority_map { + // TODO: Correctly shuffle records + results.extend(records.into_iter().map(|record| Endpoint { + host: record.target().to_utf8(), + port: record.port(), + + host_header: host.to_string(), + tls_name: host.to_string(), + })) + } + + Ok(results) + } +} + +async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer> +where + C: Service<Uri> + Clone + Sync + Send + 'static, + C::Error: Into<Box<dyn std::error::Error + Send + Sync>>, + C::Future: Unpin + Send, + C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, +{ + // TODO: Add timeout. + + let uri = hyper::Uri::builder() + .scheme("https") + .authority(host) + .path_and_query("/.well-known/matrix/server") + .build() + .ok()?; + + let mut body = http_client.get(uri).await.ok()?.into_body(); + + let mut vec = Vec::new(); + while let Some(next) = body.next().await { + let chunk = next.ok()?; + vec.extend(chunk); + } + + serde_json::from_slice(&vec).ok()? 
+} + +#[derive(Deserialize)] +struct WellKnownServer { + #[serde(rename = "m.server")] + server: String, +} + +#[derive(Clone)] +pub struct MatrixConnector { + resolver: MatrixResolver, +} + +impl MatrixConnector { + pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector { + MatrixConnector { resolver } + } +} + +impl Service<Uri> for MatrixConnector { + type Response = MaybeHttpsStream<TcpStream>; + type Error = Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + // This connector is always ready, but others might not be. + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let resolver = self.resolver.clone(); + + if dst.scheme_str() != Some("matrix") { + return HttpsConnector::new() + .call(dst) + .map_err(|e| Error::msg(e)) + .boxed(); + } + + async move { + let endpoints = resolver + .resolve_server_name_from_host_port( + dst.host().expect("hostname").to_string(), + dst.port_u16(), + ) + .await?; + + for endpoint in endpoints { + match try_connecting(&dst, &endpoint).await { + Ok(r) => return Ok(r), + // Errors here are not unexpected, and we just move on + // with our lives. + Err(e) => info!( + "Failed to connect to {} via {}:{} because {}", + dst.host().expect("hostname"), + endpoint.host, + endpoint.port, + e, + ), + } + } + + bail!( + "failed to resolve host: {:?} port {:?}", + dst.host(), + dst.port() + ) + } + .boxed() + } +} + +/// Attempts to connect to a particular endpoint. +async fn try_connecting( + dst: &Uri, + endpoint: &Endpoint, +) -> Result<MaybeHttpsStream<TcpStream>, Error> { + let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?; + + let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .build()? 
+ .into() + } else { + TlsConnector::new().unwrap().into() + }; + + let tls = connector.connect(&endpoint.tls_name, tcp).await?; + + Ok(tls.into()) +} + +/// A connector that returns a connection which returns 200 OK to all connections. +#[derive(Clone)] +pub struct TestConnector; + +impl Service<Uri> for TestConnector { + type Response = TestConnection; + type Error = Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + // This connector is always ready, but others might not be. + Poll::Ready(Ok(())) + } + + fn call(&mut self, _dst: Uri) -> Self::Future { + let (client, server) = TestConnection::double_ended(); + + { + let service = hyper::service::service_fn(|_| async move { + Ok(hyper::Response::new(hyper::Body::from("Hello World"))) + as Result<_, hyper::http::Error> + }); + let fut = Http::new().serve_connection(server, service); + tokio::spawn(fut); + } + + futures::future::ok(client).boxed() + } +} + +#[derive(Default)] +struct TestConnectionInner { + outbound_buffer: Cursor<Vec<u8>>, + inbound_buffer: Cursor<Vec<u8>>, + wakers: Vec<futures::task::Waker>, +} + +/// An in-memory connection for use with tests. 
+#[derive(Clone, Default)] +pub struct TestConnection { + inner: Arc<Mutex<TestConnectionInner>>, + direction: bool, +} + +impl TestConnection { + pub fn double_ended() -> (TestConnection, TestConnection) { + let inner: Arc<Mutex<TestConnectionInner>> = Arc::default(); + + let a = TestConnection { + inner: inner.clone(), + direction: false, + }; + + let b = TestConnection { + inner, + direction: true, + }; + + (a, b) + } +} + +impl AsyncRead for TestConnection { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let mut conn = self.inner.lock().expect("mutex"); + + let buffer = if self.direction { + &mut conn.inbound_buffer + } else { + &mut conn.outbound_buffer + }; + + let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?; + buf.advance(bytes_read); + if bytes_read > 0 { + Poll::Ready(Ok(())) + } else { + conn.wakers.push(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for TestConnection { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + let mut conn = self.inner.lock().expect("mutex"); + + if self.direction { + conn.outbound_buffer.get_mut().extend_from_slice(buf); + } else { + conn.inbound_buffer.get_mut().extend_from_slice(buf); + } + + for waker in conn.wakers.drain(..) 
{ + waker.wake() + } + + Poll::Ready(Ok(buf.len())) + } + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let mut conn = self.inner.lock().expect("mutex"); + + if self.direction { + Pin::new(&mut conn.outbound_buffer).poll_flush(cx) + } else { + Pin::new(&mut conn.inbound_buffer).poll_flush(cx) + } + } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + let mut conn = self.inner.lock().expect("mutex"); + + if self.direction { + Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx) + } else { + Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx) + } + } +} + +impl Connection for TestConnection { + fn connected(&self) -> Connected { + Connected::new() + } +} + +#[tokio::test] +async fn test_memory_connection() { + let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector); + + let response = client + .get("http://localhost".parse().unwrap()) + .await + .unwrap(); + + assert!(response.status().is_success()); + + let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap(); + assert_eq!(&bytes[..], b"Hello World"); +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c7b60e58a7..1a19705b4f 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,5 +1,6 @@ use pyo3::prelude::*; +pub mod http; pub mod push; /// Returns the hash of all the rust source files at the time it was compiled. @@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; push::register_module(py, m)?; + http::register_module(py, m)?; Ok(()) } |