summary refs log tree commit diff
path: root/scripts-dev/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev/federation_client.py')
-rwxr-xr-xscripts-dev/federation_client.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 4c758e5424..fb879ef555 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -43,7 +43,7 @@ import argparse
 import base64
 import json
 import sys
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Mapping, Optional, Tuple, Union
 from urllib import parse as urlparse
 
 import requests
@@ -75,7 +75,7 @@ def encode_canonical_json(value: object) -> bytes:
         value,
         # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
         ensure_ascii=False,
-        # Remove unecessary white space.
+        # Remove unnecessary white space.
         separators=(",", ":"),
         # Sort the keys of dictionaries.
         sort_keys=True,
@@ -298,12 +298,23 @@ class MatrixConnectionAdapter(HTTPAdapter):
 
         return super().send(request, *args, **kwargs)
 
-    def get_connection(
-        self, url: str, proxies: Optional[Dict[str, str]] = None
+    def get_connection_with_tls_context(
+        self,
+        request: PreparedRequest,
+        verify: Optional[Union[bool, str]],
+        proxies: Optional[Mapping[str, str]] = None,
+        cert: Optional[Union[Tuple[str, str], str]] = None,
     ) -> HTTPConnectionPool:
-        # overrides the get_connection() method in the base class
-        parsed = urlparse.urlsplit(url)
-        (host, port, ssl_server_name) = self._lookup(parsed.netloc)
+        # overrides the get_connection_with_tls_context() method in the base class
+        parsed = urlparse.urlsplit(request.url)
+
+        # Extract the server name from the request URL, and ensure it's a str.
+        hostname = parsed.netloc
+        if isinstance(hostname, bytes):
+            hostname = hostname.decode("utf-8")
+        assert isinstance(hostname, str)
+
+        (host, port, ssl_server_name) = self._lookup(hostname)
         print(
             f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
         )