diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 3b28417376..d2acc7654d 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -18,14 +18,22 @@
from __future__ import print_function
import argparse
+from urlparse import urlparse, urlunparse
+
import nacl.signing
import json
import base64
import requests
import sys
+
+from requests.adapters import HTTPAdapter
import srvlookup
import yaml
+# uncomment the following to enable debug logging of http requests
+#from httplib import HTTPConnection
+#HTTPConnection.debuglevel = 1
+
def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding."""
@@ -113,17 +121,6 @@ def read_signing_keys(stream):
return keys
-def lookup(destination, path):
- if ":" in destination:
- return "https://%s%s" % (destination, path)
- else:
- try:
- srv = srvlookup.lookup("matrix", "tcp", destination)[0]
- return "https://%s:%d%s" % (srv.host, srv.port, path)
- except:
- return "https://%s:%d%s" % (destination, 8448, path)
-
-
def request_json(method, origin_name, origin_key, destination, path, content):
if method is None:
if content is None:
@@ -152,13 +149,19 @@ def request_json(method, origin_name, origin_key, destination, path, content):
authorization_headers.append(bytes(header))
print ("Authorization: %s" % header, file=sys.stderr)
- dest = lookup(destination, path)
+ dest = "matrix://%s%s" % (destination, path)
print ("Requesting %s" % dest, file=sys.stderr)
- result = requests.request(
+ s = requests.Session()
+ s.mount("matrix://", MatrixConnectionAdapter())
+
+ result = s.request(
method=method,
url=dest,
- headers={"Authorization": authorization_headers[0]},
+ headers={
+ "Host": destination,
+ "Authorization": authorization_headers[0]
+ },
verify=False,
data=content,
)
@@ -242,5 +245,39 @@ def read_args_from_config(args):
args.signing_key_path = config['signing_key_path']
+class MatrixConnectionAdapter(HTTPAdapter):
+ @staticmethod
+ def lookup(s):
+ if s[-1] == ']':
+ # ipv6 literal (with no port)
+ return s, 8448
+
+ if ":" in s:
+ out = s.rsplit(":",1)
+ try:
+ port = int(out[1])
+ except ValueError:
+ raise ValueError("Invalid host:port '%s'" % s)
+ return out[0], port
+
+ try:
+ srv = srvlookup.lookup("matrix", "tcp", s)[0]
+ return srv.host, srv.port
+ except:
+ return s, 8448
+
+ def get_connection(self, url, proxies=None):
+ parsed = urlparse(url)
+
+ (host, port) = self.lookup(parsed.netloc)
+ netloc = "%s:%d" % (host, port)
+ print("Connecting to %s" % (netloc,), file=sys.stderr)
+ url = urlunparse((
+ "https", netloc, parsed.path, parsed.params, parsed.query,
+ parsed.fragment,
+ ))
+ return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
+
+
if __name__ == "__main__":
main()
|