summary refs log tree commit diff
path: root/scripts-dev/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev/federation_client.py')
-rwxr-xr-x[-rw-r--r--]scripts-dev/federation_client.py180
1 files changed, 157 insertions, 23 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index d1ab42d3af..d2acc7654d 100644..100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -1,10 +1,38 @@
+#!/usr/bin/env python
+#
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+from urlparse import urlparse, urlunparse
+
 import nacl.signing
 import json
 import base64
 import requests
 import sys
+
+from requests.adapters import HTTPAdapter
 import srvlookup
+import yaml
 
+# uncomment the following to enable debug logging of http requests
+#from httplib import HTTPConnection
+#HTTPConnection.debuglevel = 1
 
 def encode_base64(input_bytes):
     """Encode bytes as a base64 string without any padding."""
@@ -93,25 +121,24 @@ def read_signing_keys(stream):
     return keys
 
 
-def lookup(destination, path):
-    if ":" in destination:
-        return "https://%s%s" % (destination, path)
-    else:
-        try:
-            srv = srvlookup.lookup("matrix", "tcp", destination)[0]
-            return "https://%s:%d%s" % (srv.host, srv.port, path)
-        except:
-            return "https://%s:%d%s" % (destination, 8448, path)
+def request_json(method, origin_name, origin_key, destination, path, content):
+    if method is None:
+        if content is None:
+            method = "GET"
+        else:
+            method = "POST"
 
-def get_json(origin_name, origin_key, destination, path):
-    request_json = {
-        "method": "GET",
+    json_to_sign = {
+        "method": method,
         "uri": path,
         "origin": origin_name,
         "destination": destination,
     }
 
-    signed_json = sign_json(request_json, origin_key, origin_name)
+    if content is not None:
+        json_to_sign["content"] = json.loads(content)
+
+    signed_json = sign_json(json_to_sign, origin_key, origin_name)
 
     authorization_headers = []
 
@@ -120,30 +147,137 @@ def get_json(origin_name, origin_key, destination, path):
             origin_name, key, sig,
         )
         authorization_headers.append(bytes(header))
-        sys.stderr.write(header)
-        sys.stderr.write("\n")
+        print ("Authorization: %s" % header, file=sys.stderr)
+
+    dest = "matrix://%s%s" % (destination, path)
+    print ("Requesting %s" % dest, file=sys.stderr)
 
-    result = requests.get(
-        lookup(destination, path),
-        headers={"Authorization": authorization_headers[0]},
+    s = requests.Session()
+    s.mount("matrix://", MatrixConnectionAdapter())
+
+    result = s.request(
+        method=method,
+        url=dest,
+        headers={
+            "Host": destination,
+            "Authorization": authorization_headers[0]
+        },
         verify=False,
+        data=content,
     )
     sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()
 
 
 def main():
-    origin_name, keyfile, destination, path = sys.argv[1:]
+    parser = argparse.ArgumentParser(
+        description=
+            "Signs and sends a federation request to a matrix homeserver",
+    )
+
+    parser.add_argument(
+        "-N", "--server-name",
+        help="Name to give as the local homeserver. If unspecified, will be "
+             "read from the config file.",
+    )
+
+    parser.add_argument(
+        "-k", "--signing-key-path",
+        help="Path to the file containing the private ed25519 key to sign the "
+             "request with.",
+    )
+
+    parser.add_argument(
+        "-c", "--config",
+        default="homeserver.yaml",
+        help="Path to server config file. Ignored if --server-name and "
+             "--signing-key-path are both given.",
+    )
+
+    parser.add_argument(
+        "-d", "--destination",
+        default="matrix.org",
+        help="name of the remote homeserver. We will do SRV lookups and "
+             "connect appropriately.",
+    )
+
+    parser.add_argument(
+        "-X", "--method",
+        help="HTTP method to use for the request. Defaults to GET if --data is"
+             "unspecified, POST if it is."
+    )
+
+    parser.add_argument(
+        "--body",
+        help="Data to send as the body of the HTTP request"
+    )
+
+    parser.add_argument(
+        "path",
+        help="request path. We will add '/_matrix/federation/v1/' to this."
+    )
 
-    with open(keyfile) as f:
+    args = parser.parse_args()
+
+    if not args.server_name or not args.signing_key_path:
+        read_args_from_config(args)
+
+    with open(args.signing_key_path) as f:
         key = read_signing_keys(f)[0]
 
-    result = get_json(
-        origin_name, key, destination, "/_matrix/federation/v1/" + path
+    result = request_json(
+        args.method,
+        args.server_name, key, args.destination,
+        "/_matrix/federation/v1/" + args.path,
+        content=args.body,
     )
 
     json.dump(result, sys.stdout)
-    print ""
+    print ("")
+
+
+def read_args_from_config(args):
+    with open(args.config, 'r') as fh:
+        config = yaml.safe_load(fh)
+        if not args.server_name:
+            args.server_name = config['server_name']
+        if not args.signing_key_path:
+            args.signing_key_path = config['signing_key_path']
+
+
+class MatrixConnectionAdapter(HTTPAdapter):
+    @staticmethod
+    def lookup(s):
+        if s[-1] == ']':
+            # ipv6 literal (with no port)
+            return s, 8448
+
+        if ":" in s:
+            out = s.rsplit(":",1)
+            try:
+                port = int(out[1])
+            except ValueError:
+                raise ValueError("Invalid host:port '%s'" % s)
+            return out[0], port
+
+        try:
+            srv = srvlookup.lookup("matrix", "tcp", s)[0]
+            return srv.host, srv.port
+        except:
+            return s, 8448
+
+    def get_connection(self, url, proxies=None):
+        parsed = urlparse(url)
+
+        (host, port) = self.lookup(parsed.netloc)
+        netloc = "%s:%d" % (host, port)
+        print("Connecting to %s" % (netloc,), file=sys.stderr)
+        url = urlunparse((
+            "https", netloc, parsed.path, parsed.params, parsed.query,
+            parsed.fragment,
+        ))
+        return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
+
 
 if __name__ == "__main__":
     main()