summary refs log tree commit diff
path: root/scripts-dev/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev/federation_client.py')
-rwxr-xr-xscripts-dev/federation_client.py122
1 files changed, 87 insertions, 35 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 763dd02c47..b1d5e2e616 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -46,11 +46,12 @@ import signedjson.key
 import signedjson.types
 import srvlookup
 import yaml
+from requests import PreparedRequest, Response
 from requests.adapters import HTTPAdapter
 from urllib3 import HTTPConnectionPool
 
 # uncomment the following to enable debug logging of http requests
-# from httplib import HTTPConnection
+# from http.client import HTTPConnection
 # HTTPConnection.debuglevel = 1
 
 
@@ -103,6 +104,7 @@ def request(
     destination: str,
     path: str,
     content: Optional[str],
+    verify_tls: bool,
 ) -> requests.Response:
     if method is None:
         if content is None:
@@ -141,7 +143,6 @@ def request(
     s.mount("matrix://", MatrixConnectionAdapter())
 
     headers: Dict[str, str] = {
-        "Host": destination,
         "Authorization": authorization_headers[0],
     }
 
@@ -152,7 +153,7 @@ def request(
         method=method,
         url=dest,
         headers=headers,
-        verify=False,
+        verify=verify_tls,
         data=content,
         stream=True,
     )
@@ -203,6 +204,12 @@ def main() -> None:
     parser.add_argument("--body", help="Data to send as the body of the HTTP request")
 
     parser.add_argument(
+        "--insecure",
+        action="store_true",
+        help="Disable TLS certificate verification",
+    )
+
+    parser.add_argument(
         "path", help="request path, including the '/_matrix/federation/...' prefix."
     )
 
@@ -227,6 +234,7 @@ def main() -> None:
         args.destination,
         args.path,
         content=args.body,
+        verify_tls=not args.insecure,
     )
 
     sys.stderr.write("Status Code: %d\n" % (result.status_code,))
@@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:
 
 
 class MatrixConnectionAdapter(HTTPAdapter):
+    def send(
+        self,
+        request: PreparedRequest,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Response:
+        # overrides the send() method in the base class.
+
+        # We need to look for .well-known redirects before passing the request up to
+        # HTTPAdapter.send().
+        assert isinstance(request.url, str)
+        parsed = urlparse.urlsplit(request.url)
+        server_name = parsed.netloc
+        well_known = self._get_well_known(parsed.netloc)
+
+        if well_known:
+            server_name = well_known
+
+        # replace the scheme in the uri with https, so that cert verification is done
+        # also replace the hostname if we got a .well-known result
+        request.url = urlparse.urlunsplit(
+            ("https", server_name, parsed.path, parsed.query, parsed.fragment)
+        )
+
+        # at this point we also add the host header (otherwise urllib will add one
+        # based on the `host` from the connection returned by `get_connection`,
+        # which will be wrong if there is an SRV record).
+        request.headers["Host"] = server_name
+
+        return super().send(request, *args, **kwargs)
+
+    def get_connection(
+        self, url: str, proxies: Optional[Dict[str, str]] = None
+    ) -> HTTPConnectionPool:
+        # overrides the get_connection() method in the base class
+        parsed = urlparse.urlsplit(url)
+        (host, port, ssl_server_name) = self._lookup(parsed.netloc)
+        print(
+            f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
+        )
+        return self.poolmanager.connection_from_host(
+            host,
+            port=port,
+            scheme="https",
+            pool_kwargs={"server_hostname": ssl_server_name},
+        )
+
     @staticmethod
-    def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
-        if s[-1] == "]":
+    def _lookup(server_name: str) -> Tuple[str, int, str]:
+        """
+        Do an SRV lookup on a server name and return the host:port to connect to
+        Given the server_name (after any .well-known lookup), return the host, port and
+        the ssl server name
+        """
+        if server_name[-1] == "]":
             # ipv6 literal (with no port)
-            return s, 8448
+            return server_name, 8448, server_name
 
-        if ":" in s:
-            out = s.rsplit(":", 1)
+        if ":" in server_name:
+            # explicit port
+            out = server_name.rsplit(":", 1)
             try:
                 port = int(out[1])
             except ValueError:
-                raise ValueError("Invalid host:port '%s'" % s)
-            return out[0], port
-
-        # try a .well-known lookup
-        if not skip_well_known:
-            well_known = MatrixConnectionAdapter.get_well_known(s)
-            if well_known:
-                return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
+                raise ValueError("Invalid host:port '%s'" % (server_name,))
+            return out[0], port, out[0]
 
         try:
-            srv = srvlookup.lookup("matrix", "tcp", s)[0]
-            return srv.host, srv.port
+            srv = srvlookup.lookup("matrix", "tcp", server_name)[0]
+            print(
+                f"SRV lookup on _matrix._tcp.{server_name} gave {srv}",
+                file=sys.stderr,
+            )
+            return srv.host, srv.port, server_name
         except Exception:
-            return s, 8448
+            return server_name, 8448, server_name
 
     @staticmethod
-    def get_well_known(server_name: str) -> Optional[str]:
-        uri = "https://%s/.well-known/matrix/server" % (server_name,)
-        print("fetching %s" % (uri,), file=sys.stderr)
+    def _get_well_known(server_name: str) -> Optional[str]:
+        if ":" in server_name:
+            # explicit port, or ipv6 literal. Either way, no .well-known
+            return None
+
+        # TODO: check for ipv4 literals
+
+        uri = f"https://{server_name}/.well-known/matrix/server"
+        print(f"fetching {uri}", file=sys.stderr)
 
         try:
             resp = requests.get(uri)
@@ -304,19 +369,6 @@ class MatrixConnectionAdapter(HTTPAdapter):
             print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
         return None
 
-    def get_connection(
-        self, url: str, proxies: Optional[Dict[str, str]] = None
-    ) -> HTTPConnectionPool:
-        parsed = urlparse.urlparse(url)
-
-        (host, port) = self.lookup(parsed.netloc)
-        netloc = "%s:%d" % (host, port)
-        print("Connecting to %s" % (netloc,), file=sys.stderr)
-        url = urlparse.urlunparse(
-            ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
-        )
-        return super().get_connection(url, proxies)
-
 
 if __name__ == "__main__":
     main()