diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 763dd02c47..b1d5e2e616 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -46,11 +46,12 @@ import signedjson.key
import signedjson.types
import srvlookup
import yaml
+from requests import PreparedRequest, Response
from requests.adapters import HTTPAdapter
from urllib3 import HTTPConnectionPool
# uncomment the following to enable debug logging of http requests
-# from httplib import HTTPConnection
+# from http.client import HTTPConnection
# HTTPConnection.debuglevel = 1
@@ -103,6 +104,7 @@ def request(
destination: str,
path: str,
content: Optional[str],
+ verify_tls: bool,
) -> requests.Response:
if method is None:
if content is None:
@@ -141,7 +143,6 @@ def request(
s.mount("matrix://", MatrixConnectionAdapter())
headers: Dict[str, str] = {
- "Host": destination,
"Authorization": authorization_headers[0],
}
@@ -152,7 +153,7 @@ def request(
method=method,
url=dest,
headers=headers,
- verify=False,
+ verify=verify_tls,
data=content,
stream=True,
)
@@ -203,6 +204,12 @@ def main() -> None:
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
parser.add_argument(
+ "--insecure",
+ action="store_true",
+ help="Disable TLS certificate verification",
+ )
+
+ parser.add_argument(
"path", help="request path, including the '/_matrix/federation/...' prefix."
)
@@ -227,6 +234,7 @@ def main() -> None:
args.destination,
args.path,
content=args.body,
+ verify_tls=not args.insecure,
)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
@@ -254,36 +262,93 @@ def read_args_from_config(args: argparse.Namespace) -> None:
class MatrixConnectionAdapter(HTTPAdapter):
+ def send(
+ self,
+ request: PreparedRequest,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Response:
+ # overrides the send() method in the base class.
+
+ # We need to look for .well-known redirects before passing the request up to
+ # HTTPAdapter.send().
+ assert isinstance(request.url, str)
+ parsed = urlparse.urlsplit(request.url)
+ server_name = parsed.netloc
+ well_known = self._get_well_known(parsed.netloc)
+
+ if well_known:
+ server_name = well_known
+
+ # replace the scheme in the uri with https, so that cert verification is done
+ # also replace the hostname if we got a .well-known result
+ request.url = urlparse.urlunsplit(
+ ("https", server_name, parsed.path, parsed.query, parsed.fragment)
+ )
+
+ # at this point we also add the host header (otherwise urllib will add one
+ # based on the `host` from the connection returned by `get_connection`,
+ # which will be wrong if there is an SRV record).
+ request.headers["Host"] = server_name
+
+ return super().send(request, *args, **kwargs)
+
+ def get_connection(
+ self, url: str, proxies: Optional[Dict[str, str]] = None
+ ) -> HTTPConnectionPool:
+ # overrides the get_connection() method in the base class
+ parsed = urlparse.urlsplit(url)
+ (host, port, ssl_server_name) = self._lookup(parsed.netloc)
+ print(
+ f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
+ )
+ return self.poolmanager.connection_from_host(
+ host,
+ port=port,
+ scheme="https",
+ pool_kwargs={"server_hostname": ssl_server_name},
+ )
+
@staticmethod
- def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
- if s[-1] == "]":
+ def _lookup(server_name: str) -> Tuple[str, int, str]:
+ """
+ Do an SRV lookup on a server name and return the host:port to connect to
+ Given the server_name (after any .well-known lookup), return the host, port and
+ the ssl server name
+ """
+ if server_name[-1] == "]":
# ipv6 literal (with no port)
- return s, 8448
+ return server_name, 8448, server_name
- if ":" in s:
- out = s.rsplit(":", 1)
+ if ":" in server_name:
+ # explicit port
+ out = server_name.rsplit(":", 1)
try:
port = int(out[1])
except ValueError:
- raise ValueError("Invalid host:port '%s'" % s)
- return out[0], port
-
- # try a .well-known lookup
- if not skip_well_known:
- well_known = MatrixConnectionAdapter.get_well_known(s)
- if well_known:
- return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
+ raise ValueError("Invalid host:port '%s'" % (server_name,))
+ return out[0], port, out[0]
try:
- srv = srvlookup.lookup("matrix", "tcp", s)[0]
- return srv.host, srv.port
+ srv = srvlookup.lookup("matrix", "tcp", server_name)[0]
+ print(
+ f"SRV lookup on _matrix._tcp.{server_name} gave {srv}",
+ file=sys.stderr,
+ )
+ return srv.host, srv.port, server_name
except Exception:
- return s, 8448
+ return server_name, 8448, server_name
@staticmethod
- def get_well_known(server_name: str) -> Optional[str]:
- uri = "https://%s/.well-known/matrix/server" % (server_name,)
- print("fetching %s" % (uri,), file=sys.stderr)
+ def _get_well_known(server_name: str) -> Optional[str]:
+ if ":" in server_name:
+ # explicit port, or ipv6 literal. Either way, no .well-known
+ return None
+
+ # TODO: check for ipv4 literals
+
+ uri = f"https://{server_name}/.well-known/matrix/server"
+ print(f"fetching {uri}", file=sys.stderr)
try:
resp = requests.get(uri)
@@ -304,19 +369,6 @@ class MatrixConnectionAdapter(HTTPAdapter):
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None
- def get_connection(
- self, url: str, proxies: Optional[Dict[str, str]] = None
- ) -> HTTPConnectionPool:
- parsed = urlparse.urlparse(url)
-
- (host, port) = self.lookup(parsed.netloc)
- netloc = "%s:%d" % (host, port)
- print("Connecting to %s" % (netloc,), file=sys.stderr)
- url = urlparse.urlunparse(
- ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
- )
- return super().get_connection(url, proxies)
-
if __name__ == "__main__":
main()
|