summary refs log tree commit diff
path: root/rust/src/http/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/src/http/mod.rs')
-rw-r--r--rust/src/http/mod.rs34
1 files changed, 23 insertions, 11 deletions
diff --git a/rust/src/http/mod.rs b/rust/src/http/mod.rs
index 508f7cb048..c764f7c76a 100644
--- a/rust/src/http/mod.rs
+++ b/rust/src/http/mod.rs
@@ -1,7 +1,7 @@
 use std::collections::HashMap;
 
 use anyhow::Error;
-use http::Request;
+use http::{Request, Uri};
 use hyper::Body;
 use log::info;
 use pyo3::{
@@ -12,7 +12,7 @@ use pyo3::{
 
 use self::resolver::{MatrixConnector, MatrixResolver};
 
-mod resolver;
+pub mod resolver;
 
 /// Called when registering modules with python.
 pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -31,8 +31,8 @@ pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     Ok(())
 }
 
-#[derive(Clone)]
-struct Bytes(Vec<u8>);
+#[derive(Clone, Debug)]
+pub struct Bytes(pub Vec<u8>);
 
 impl ToPyObject for Bytes {
     fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
@@ -46,31 +46,34 @@ impl IntoPy<PyObject> for Bytes {
     }
 }
 
+#[derive(Debug)]
 #[pyclass]
 pub struct MatrixResponse {
     #[pyo3(get)]
-    code: u16,
+    pub code: u16,
     #[pyo3(get)]
-    phrase: &'static str,
+    pub phrase: &'static str,
     #[pyo3(get)]
-    content: Bytes,
+    pub content: Bytes,
     #[pyo3(get)]
-    headers: HashMap<String, Bytes>,
+    pub headers: HashMap<String, Bytes>,
 }
 
 #[pyclass]
 #[derive(Clone)]
 pub struct HttpClient {
     client: hyper::Client<MatrixConnector>,
+    resolver: MatrixResolver,
 }
 
 impl HttpClient {
     pub fn new() -> Result<Self, Error> {
         let resolver = MatrixResolver::new()?;
 
-        let client = hyper::Client::builder().build(MatrixConnector::with_resolver(resolver));
+        let client =
+            hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
 
-        Ok(HttpClient { client })
+        Ok(HttpClient { client, resolver })
     }
 
     pub async fn async_request(
@@ -80,7 +83,9 @@ impl HttpClient {
         headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
         body: Option<Vec<u8>>,
     ) -> Result<MatrixResponse, Error> {
-        let mut builder = Request::builder().method(&*method).uri(url);
+        let uri: Uri = url.try_into()?;
+
+        let mut builder = Request::builder().method(&*method).uri(uri.clone());
 
         for (key, values) in headers {
             for value in values {
@@ -88,6 +93,13 @@ impl HttpClient {
             }
         }
 
+        if uri.scheme_str() == Some("matrix") {
+            let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
+            if let Some(endpoint) = endpoints.first() {
+                builder = builder.header("Host", &endpoint.host_header);
+            }
+        }
+
         let request = if let Some(body) = body {
             builder.body(Body::from(body))?
         } else {