diff options
Diffstat (limited to 'rust/src/http/mod.rs')
-rw-r--r-- | rust/src/http/mod.rs | 34 |
1 files changed, 23 insertions, 11 deletions
diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs index 508f7cb048..c764f7c76a 100644 --- a/rust/src/http/mod.rs +++ b/rust/src/http/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use anyhow::Error; -use http::Request; +use http::{Request, Uri}; use hyper::Body; use log::info; use pyo3::{ @@ -12,7 +12,7 @@ use pyo3::{ use self::resolver::{MatrixConnector, MatrixResolver}; -mod resolver; +pub mod resolver; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { @@ -31,8 +31,8 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } -#[derive(Clone)] -struct Bytes(Vec<u8>); +#[derive(Clone, Debug)] +pub struct Bytes(pub Vec<u8>); impl ToPyObject for Bytes { fn to_object(&self, py: Python<'_>) -> pyo3::PyObject { @@ -46,31 +46,34 @@ impl IntoPy<PyObject> for Bytes { } } +#[derive(Debug)] #[pyclass] pub struct MatrixResponse { #[pyo3(get)] - code: u16, + pub code: u16, #[pyo3(get)] - phrase: &'static str, + pub phrase: &'static str, #[pyo3(get)] - content: Bytes, + pub content: Bytes, #[pyo3(get)] - headers: HashMap<String, Bytes>, + pub headers: HashMap<String, Bytes>, } #[pyclass] #[derive(Clone)] pub struct HttpClient { client: hyper::Client<MatrixConnector>, + resolver: MatrixResolver, } impl HttpClient { pub fn new() -> Result<Self, Error> { let resolver = MatrixResolver::new()?; - let client = hyper::Client::builder().build(MatrixConnector::with_resolver(resolver)); + let client = + hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone())); - Ok(HttpClient { client }) + Ok(HttpClient { client, resolver }) } pub async fn async_request( @@ -80,7 +83,9 @@ impl HttpClient { headers: HashMap<Vec<u8>, Vec<Vec<u8>>>, body: Option<Vec<u8>>, ) -> Result<MatrixResponse, Error> { - let mut builder = Request::builder().method(&*method).uri(url); + let uri: Uri = url.try_into()?; + + let mut builder = Request::builder().method(&*method).uri(uri.clone()); for (key, values) in headers { for value in values { @@ -88,6 +93,13 @@ impl HttpClient { } } + if uri.scheme_str() == Some("matrix") { + let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?; + if let Some(endpoint) = endpoints.first() { + builder = builder.header("Host", &endpoint.host_header); + } + } + let request = if let Some(body) = body { builder.body(Body::from(body))? } else { |