summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xscripts-dev/federation_client.py65
1 files changed, 51 insertions, 14 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 3b28417376..d2acc7654d 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -18,14 +18,22 @@
 from __future__ import print_function
 
 import argparse
+from urlparse import urlparse, urlunparse
+
 import nacl.signing
 import json
 import base64
 import requests
 import sys
+
+from requests.adapters import HTTPAdapter
 import srvlookup
 import yaml
 
+# uncomment the following to enable debug logging of http requests
+#from httplib import HTTPConnection
+#HTTPConnection.debuglevel = 1
+
 def encode_base64(input_bytes):
     """Encode bytes as a base64 string without any padding."""
 
@@ -113,17 +121,6 @@ def read_signing_keys(stream):
     return keys
 
 
-def lookup(destination, path):
-    if ":" in destination:
-        return "https://%s%s" % (destination, path)
-    else:
-        try:
-            srv = srvlookup.lookup("matrix", "tcp", destination)[0]
-            return "https://%s:%d%s" % (srv.host, srv.port, path)
-        except:
-            return "https://%s:%d%s" % (destination, 8448, path)
-
-
 def request_json(method, origin_name, origin_key, destination, path, content):
     if method is None:
         if content is None:
@@ -152,13 +149,19 @@ def request_json(method, origin_name, origin_key, destination, path, content):
         authorization_headers.append(bytes(header))
         print ("Authorization: %s" % header, file=sys.stderr)
 
-    dest = lookup(destination, path)
+    dest = "matrix://%s%s" % (destination, path)
     print ("Requesting %s" % dest, file=sys.stderr)
 
-    result = requests.request(
+    s = requests.Session()
+    s.mount("matrix://", MatrixConnectionAdapter())
+
+    result = s.request(
         method=method,
         url=dest,
-        headers={"Authorization": authorization_headers[0]},
+        headers={
+            "Host": destination,
+            "Authorization": authorization_headers[0]
+        },
         verify=False,
         data=content,
     )
@@ -242,5 +245,39 @@ def read_args_from_config(args):
             args.signing_key_path = config['signing_key_path']
 
 
+class MatrixConnectionAdapter(HTTPAdapter):
+    @staticmethod
+    def lookup(s):
+        if s[-1] == ']':
+            # ipv6 literal (with no port)
+            return s, 8448
+
+        if ":" in s:
+            out = s.rsplit(":",1)
+            try:
+                port = int(out[1])
+            except ValueError:
+                raise ValueError("Invalid host:port '%s'" % s)
+            return out[0], port
+
+        try:
+            srv = srvlookup.lookup("matrix", "tcp", s)[0]
+            return srv.host, srv.port
+        except:
+            return s, 8448
+
+    def get_connection(self, url, proxies=None):
+        parsed = urlparse(url)
+
+        (host, port) = self.lookup(parsed.netloc)
+        netloc = "%s:%d" % (host, port)
+        print("Connecting to %s" % (netloc,), file=sys.stderr)
+        url = urlunparse((
+            "https", netloc, parsed.path, parsed.params, parsed.query,
+            parsed.fragment,
+        ))
+        return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
+
+
 if __name__ == "__main__":
     main()