diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs
index 508f7cb048..c764f7c76a 100644
--- a/rust/src/http/mod.rs
+++ b/rust/src/http/mod.rs
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use anyhow::Error;
-use http::Request;
+use http::{Request, Uri};
use hyper::Body;
use log::info;
use pyo3::{
@@ -12,7 +12,7 @@ use pyo3::{
use self::resolver::{MatrixConnector, MatrixResolver};
-mod resolver;
+pub mod resolver;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -31,8 +31,8 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
-#[derive(Clone)]
-struct Bytes(Vec<u8>);
+#[derive(Clone, Debug)]
+pub struct Bytes(pub Vec<u8>);
impl ToPyObject for Bytes {
fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
@@ -46,31 +46,34 @@ impl IntoPy<PyObject> for Bytes {
}
}
+#[derive(Debug)]
#[pyclass]
pub struct MatrixResponse {
#[pyo3(get)]
- code: u16,
+ pub code: u16,
#[pyo3(get)]
- phrase: &'static str,
+ pub phrase: &'static str,
#[pyo3(get)]
- content: Bytes,
+ pub content: Bytes,
#[pyo3(get)]
- headers: HashMap<String, Bytes>,
+ pub headers: HashMap<String, Bytes>,
}
#[pyclass]
#[derive(Clone)]
pub struct HttpClient {
client: hyper::Client<MatrixConnector>,
+ resolver: MatrixResolver,
}
impl HttpClient {
pub fn new() -> Result<Self, Error> {
let resolver = MatrixResolver::new()?;
- let client = hyper::Client::builder().build(MatrixConnector::with_resolver(resolver));
+ let client =
+ hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
- Ok(HttpClient { client })
+ Ok(HttpClient { client, resolver })
}
pub async fn async_request(
@@ -80,7 +83,9 @@ impl HttpClient {
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
body: Option<Vec<u8>>,
) -> Result<MatrixResponse, Error> {
- let mut builder = Request::builder().method(&*method).uri(url);
+ let uri: Uri = url.try_into()?;
+
+ let mut builder = Request::builder().method(&*method).uri(uri.clone());
for (key, values) in headers {
for value in values {
@@ -88,6 +93,13 @@ impl HttpClient {
}
}
+ if uri.scheme_str() == Some("matrix") {
+ let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
+ if let Some(endpoint) = endpoints.first() {
+ builder = builder.header("Host", &endpoint.host_header);
+ }
+ }
+
let request = if let Some(body) = body {
builder.body(Body::from(body))?
} else {
|