diff options
author | Erik Johnston <erik@matrix.org> | 2022-12-15 14:05:59 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-12-15 14:05:59 +0000 |
commit | f5817281f8bc707ed60562706091dadccf55efe5 (patch) | |
tree | a653fae7ca87a579fc261a746ba68910b52d2e9b /rust/src | |
parent | Fixup (diff) | |
download | synapse-erikj/rust_http.tar.xz |
Diffstat (limited to '')
-rw-r--r-- | rust/src/http/mod.rs | 34 | ||||
-rw-r--r-- | rust/src/http/resolver.rs | 6 |
2 files changed, 28 insertions, 12 deletions
diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs index 508f7cb048..c764f7c76a 100644 --- a/rust/src/http/mod.rs +++ b/rust/src/http/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use anyhow::Error; -use http::Request; +use http::{Request, Uri}; use hyper::Body; use log::info; use pyo3::{ @@ -12,7 +12,7 @@ use pyo3::{ use self::resolver::{MatrixConnector, MatrixResolver}; -mod resolver; +pub mod resolver; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -31,8 +31,8 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } -#[derive(Clone)] -struct Bytes(Vec<u8>); +#[derive(Clone, Debug)] +pub struct Bytes(pub Vec<u8>); impl ToPyObject for Bytes { fn to_object(&self, py: Python<'_>) -> pyo3::PyObject { @@ -46,31 +46,34 @@ impl IntoPy<PyObject> for Bytes { } } +#[derive(Debug)] #[pyclass] pub struct MatrixResponse { #[pyo3(get)] - code: u16, + pub code: u16, #[pyo3(get)] - phrase: &'static str, + pub phrase: &'static str, #[pyo3(get)] - content: Bytes, + pub content: Bytes, #[pyo3(get)] - headers: HashMap<String, Bytes>, + pub headers: HashMap<String, Bytes>, } #[pyclass] #[derive(Clone)] pub struct HttpClient { client: hyper::Client<MatrixConnector>, + resolver: MatrixResolver, } impl HttpClient { pub fn new() -> Result<Self, Error> { let resolver = MatrixResolver::new()?; - let client = hyper::Client::builder().build(MatrixConnector::with_resolver(resolver)); + let client = + hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone())); - Ok(HttpClient { client }) + Ok(HttpClient { client, resolver }) } pub async fn async_request( @@ -80,7 +83,9 @@ impl HttpClient { headers: HashMap<Vec<u8>, Vec<Vec<u8>>>, body: Option<Vec<u8>>, ) -> Result<MatrixResponse, Error> { - let mut builder = Request::builder().method(&*method).uri(url); + let uri: Uri = url.try_into()?; + + let mut builder = Request::builder().method(&*method).uri(uri.clone()); for (key, values) in headers { for value in values { @@ -88,6 +93,13 @@ impl HttpClient { } } + if uri.scheme_str() == Some("matrix") { + let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?; + if let Some(endpoint) = endpoints.first() { + builder = builder.header("Host", &endpoint.host_header); + } + } + let request = if let Some(body) = body { builder.body(Body::from(body))? } else { diff --git a/rust/src/http/resolver.rs b/rust/src/http/resolver.rs index 77e9bf5c20..0a2641ebdf 100644 --- a/rust/src/http/resolver.rs +++ b/rust/src/http/resolver.rs @@ -20,7 +20,7 @@ use hyper::service::Service; use hyper::Client; use hyper_tls::HttpsConnector; use hyper_tls::MaybeHttpsStream; -use log::info; +use log::{debug, info}; use native_tls::TlsConnector; use serde::Deserialize; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -28,6 +28,7 @@ use tokio::net::TcpStream; use tokio_native_tls::TlsConnector as AsyncTlsConnector; use trust_dns_resolver::error::ResolveErrorKind; +#[derive(Debug, Clone)] pub struct Endpoint { pub host: String, pub port: u16, @@ -213,6 +214,7 @@ impl Service<Uri> for MatrixConnector { let resolver = self.resolver.clone(); if dst.scheme_str() != Some("matrix") { + debug!("Got non-matrix scheme"); return HttpsConnector::new() .call(dst) .map_err(|e| Error::msg(e)) @@ -227,6 +229,8 @@ impl Service<Uri> for MatrixConnector { ) .await?; + debug!("Got endpoints: {:?}", endpoints); + for endpoint in endpoints { match try_connecting(&dst, &endpoint).await { Ok(r) => return Ok(r), |