summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
committerRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
commita4daa899ec4cd195fc10936f68df5c78314b366c (patch)
tree35e88ff388b0f7652773a79930b732aa04f16bde /scripts
parentchangelog (diff)
parentImprove docs on choosing server_name (#5558) (diff)
downloadsynapse-a4daa899ec4cd195fc10936f68df5c78314b366c.tar.xz
Merge branch 'develop' into rav/saml2_client
Diffstat (limited to '')
-rw-r--r--scripts-dev/check_auth.py4
-rw-r--r--scripts-dev/check_event_hash.py54
-rw-r--r--scripts-dev/check_signature.py5
-rw-r--r--scripts-dev/convert_server_keys.py2
-rwxr-xr-xscripts-dev/definitions.py62
-rwxr-xr-xscripts-dev/federation_client.py59
-rw-r--r--scripts-dev/hash_history.py2
-rwxr-xr-xscripts-dev/list_url_patterns.py4
-rw-r--r--scripts-dev/tail-synapse.py2
-rwxr-xr-xscripts/generate_signing_key.py10
-rwxr-xr-xscripts/move_remote_media_to_new_store.py4
11 files changed, 90 insertions, 118 deletions
diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py
index b3d11f49ec..2a1c5f39d4 100644
--- a/scripts-dev/check_auth.py
+++ b/scripts-dev/check_auth.py
@@ -39,11 +39,11 @@ def check_auth(auth, auth_chain, events):
         print("Success:", e.event_id, e.type, e.state_key)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin
+        "json", nargs="?", type=argparse.FileType("r"), default=sys.stdin
     )
 
     args = parser.parse_args()
diff --git a/scripts-dev/check_event_hash.py b/scripts-dev/check_event_hash.py
deleted file mode 100644
index 8535f99697..0000000000
--- a/scripts-dev/check_event_hash.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import argparse
-import hashlib
-import json
-import logging
-import sys
-
-from unpaddedbase64 import encode_base64
-
-from synapse.crypto.event_signing import (
-    check_event_content_hash,
-    compute_event_reference_hash,
-)
-
-
-class dictobj(dict):
-    def __init__(self, *args, **kargs):
-        dict.__init__(self, *args, **kargs)
-        self.__dict__ = self
-
-    def get_dict(self):
-        return dict(self)
-
-    def get_full_dict(self):
-        return dict(self)
-
-    def get_pdu_json(self):
-        return dict(self)
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
-    )
-    args = parser.parse_args()
-    logging.basicConfig()
-
-    event_json = dictobj(json.load(args.input_json))
-
-    algorithms = {"sha256": hashlib.sha256}
-
-    for alg_name in event_json.hashes:
-        if check_event_content_hash(event_json, algorithms[alg_name]):
-            print("PASS content hash %s" % (alg_name,))
-        else:
-            print("FAIL content hash %s" % (alg_name,))
-
-    for algorithm in algorithms.values():
-        name, h_bytes = compute_event_reference_hash(event_json, algorithm)
-        print("Reference hash %s: %s" % (name, encode_base64(h_bytes)))
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py
index 612f17ca7f..ecda103cf7 100644
--- a/scripts-dev/check_signature.py
+++ b/scripts-dev/check_signature.py
@@ -1,4 +1,3 @@
-
 import argparse
 import json
 import logging
@@ -40,7 +39,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("signature_name")
     parser.add_argument(
-        "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
+        "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin
     )
 
     args = parser.parse_args()
@@ -69,5 +68,5 @@ def main():
             print("FAIL %s" % (key_id,))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py
index ac152b5c42..179be61c30 100644
--- a/scripts-dev/convert_server_keys.py
+++ b/scripts-dev/convert_server_keys.py
@@ -116,5 +116,5 @@ def main():
     connection.commit()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py
index 1deb0fe2b7..9eddb6d515 100755
--- a/scripts-dev/definitions.py
+++ b/scripts-dev/definitions.py
@@ -19,10 +19,10 @@ class DefinitionVisitor(ast.NodeVisitor):
         self.names = {}
         self.attrs = set()
         self.definitions = {
-            'def': self.functions,
-            'class': self.classes,
-            'names': self.names,
-            'attrs': self.attrs,
+            "def": self.functions,
+            "class": self.classes,
+            "names": self.names,
+            "attrs": self.attrs,
         }
 
     def visit_Name(self, node):
@@ -47,23 +47,23 @@ class DefinitionVisitor(ast.NodeVisitor):
 
 
 def non_empty(defs):
-    functions = {name: non_empty(f) for name, f in defs['def'].items()}
-    classes = {name: non_empty(f) for name, f in defs['class'].items()}
+    functions = {name: non_empty(f) for name, f in defs["def"].items()}
+    classes = {name: non_empty(f) for name, f in defs["class"].items()}
     result = {}
     if functions:
-        result['def'] = functions
+        result["def"] = functions
     if classes:
-        result['class'] = classes
-    names = defs['names']
+        result["class"] = classes
+    names = defs["names"]
     uses = []
-    for name in names.get('Load', ()):
-        if name not in names.get('Param', ()) and name not in names.get('Store', ()):
+    for name in names.get("Load", ()):
+        if name not in names.get("Param", ()) and name not in names.get("Store", ()):
             uses.append(name)
-    uses.extend(defs['attrs'])
+    uses.extend(defs["attrs"])
     if uses:
-        result['uses'] = uses
-    result['names'] = names
-    result['attrs'] = defs['attrs']
+        result["uses"] = uses
+    result["names"] = names
+    result["attrs"] = defs["attrs"]
     return result
 
 
@@ -81,33 +81,33 @@ def definitions_in_file(filepath):
 
 
 def defined_names(prefix, defs, names):
-    for name, funcs in defs.get('def', {}).items():
-        names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
+    for name, funcs in defs.get("def", {}).items():
+        names.setdefault(name, {"defined": []})["defined"].append(prefix + name)
         defined_names(prefix + name + ".", funcs, names)
 
-    for name, funcs in defs.get('class', {}).items():
-        names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
+    for name, funcs in defs.get("class", {}).items():
+        names.setdefault(name, {"defined": []})["defined"].append(prefix + name)
         defined_names(prefix + name + ".", funcs, names)
 
 
 def used_names(prefix, item, defs, names):
-    for name, funcs in defs.get('def', {}).items():
+    for name, funcs in defs.get("def", {}).items():
         used_names(prefix + name + ".", name, funcs, names)
 
-    for name, funcs in defs.get('class', {}).items():
+    for name, funcs in defs.get("class", {}).items():
         used_names(prefix + name + ".", name, funcs, names)
 
-    path = prefix.rstrip('.')
-    for used in defs.get('uses', ()):
+    path = prefix.rstrip(".")
+    for used in defs.get("uses", ()):
         if used in names:
             if item:
-                names[item].setdefault('uses', []).append(used)
-            names[used].setdefault('used', {}).setdefault(item, []).append(path)
+                names[item].setdefault("uses", []).append(used)
+            names[used].setdefault("used", {}).setdefault(item, []).append(path)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
-    parser = argparse.ArgumentParser(description='Find definitions.')
+    parser = argparse.ArgumentParser(description="Find definitions.")
     parser.add_argument(
         "--unused", action="store_true", help="Only list unused definitions"
     )
@@ -119,7 +119,7 @@ if __name__ == '__main__':
     )
     parser.add_argument(
         "directories",
-        nargs='+',
+        nargs="+",
         metavar="DIR",
         help="Directories to search for definitions",
     )
@@ -164,7 +164,7 @@ if __name__ == '__main__':
             continue
         if ignore and any(pattern.match(name) for pattern in ignore):
             continue
-        if args.unused and definition.get('used'):
+        if args.unused and definition.get("used"):
             continue
         result[name] = definition
 
@@ -196,9 +196,9 @@ if __name__ == '__main__':
                 continue
             result[name] = definition
 
-    if args.format == 'yaml':
+    if args.format == "yaml":
         yaml.dump(result, sys.stdout, default_flow_style=False)
-    elif args.format == 'dot':
+    elif args.format == "dot":
         print("digraph {")
         for name, entry in result.items():
             print(name)
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index e0287c8c6c..7c19e405d4 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -21,7 +21,8 @@ import argparse
 import base64
 import json
 import sys
-from urlparse import urlparse, urlunparse
+
+from six.moves.urllib import parse as urlparse
 
 import nacl.signing
 import requests
@@ -62,7 +63,7 @@ def encode_canonical_json(value):
         # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
         ensure_ascii=False,
         # Remove unecessary white space.
-        separators=(',', ':'),
+        separators=(",", ":"),
         # Sort the keys of dictionaries.
         sort_keys=True,
         # Encode the resulting unicode as UTF-8 bytes.
@@ -144,8 +145,8 @@ def request_json(method, origin_name, origin_key, destination, path, content):
     authorization_headers = []
 
     for key, sig in signed_json["signatures"][origin_name].items():
-        header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig)
-        authorization_headers.append(bytes(header))
+        header = 'X-Matrix origin=%s,key="%s",sig="%s"' % (origin_name, key, sig)
+        authorization_headers.append(header.encode("ascii"))
         print("Authorization: %s" % header, file=sys.stderr)
 
     dest = "matrix://%s%s" % (destination, path)
@@ -160,11 +161,7 @@ def request_json(method, origin_name, origin_key, destination, path, content):
         headers["Content-Type"] = "application/json"
 
     result = s.request(
-        method=method,
-        url=dest,
-        headers=headers,
-        verify=False,
-        data=content,
+        method=method, url=dest, headers=headers, verify=False, data=content
     )
     sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()
@@ -240,18 +237,18 @@ def main():
 
 
 def read_args_from_config(args):
-    with open(args.config, 'r') as fh:
+    with open(args.config, "r") as fh:
         config = yaml.safe_load(fh)
         if not args.server_name:
-            args.server_name = config['server_name']
+            args.server_name = config["server_name"]
         if not args.signing_key_path:
-            args.signing_key_path = config['signing_key_path']
+            args.signing_key_path = config["signing_key_path"]
 
 
 class MatrixConnectionAdapter(HTTPAdapter):
     @staticmethod
-    def lookup(s):
-        if s[-1] == ']':
+    def lookup(s, skip_well_known=False):
+        if s[-1] == "]":
             # ipv6 literal (with no port)
             return s, 8448
 
@@ -263,19 +260,49 @@ class MatrixConnectionAdapter(HTTPAdapter):
                 raise ValueError("Invalid host:port '%s'" % s)
             return out[0], port
 
+        # try a .well-known lookup
+        if not skip_well_known:
+            well_known = MatrixConnectionAdapter.get_well_known(s)
+            if well_known:
+                return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True)
+
         try:
             srv = srvlookup.lookup("matrix", "tcp", s)[0]
             return srv.host, srv.port
         except Exception:
             return s, 8448
 
+    @staticmethod
+    def get_well_known(server_name):
+        uri = "https://%s/.well-known/matrix/server" % (server_name,)
+        print("fetching %s" % (uri,), file=sys.stderr)
+
+        try:
+            resp = requests.get(uri)
+            if resp.status_code != 200:
+                print("%s gave %i" % (uri, resp.status_code), file=sys.stderr)
+                return None
+
+            parsed_well_known = resp.json()
+            if not isinstance(parsed_well_known, dict):
+                raise Exception("not a dict")
+            if "m.server" not in parsed_well_known:
+                raise Exception("Missing key 'm.server'")
+            new_name = parsed_well_known["m.server"]
+            print("well-known lookup gave %s" % (new_name,), file=sys.stderr)
+            return new_name
+
+        except Exception as e:
+            print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
+        return None
+
     def get_connection(self, url, proxies=None):
-        parsed = urlparse(url)
+        parsed = urlparse.urlparse(url)
 
         (host, port) = self.lookup(parsed.netloc)
         netloc = "%s:%d" % (host, port)
         print("Connecting to %s" % (netloc,), file=sys.stderr)
-        url = urlunparse(
+        url = urlparse.urlunparse(
             ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
         )
         return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py
index 514d80fa60..d20f6db176 100644
--- a/scripts-dev/hash_history.py
+++ b/scripts-dev/hash_history.py
@@ -79,5 +79,5 @@ def main():
     conn.commit()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py
index 62e5a07472..26ad7c67f4 100755
--- a/scripts-dev/list_url_patterns.py
+++ b/scripts-dev/list_url_patterns.py
@@ -35,11 +35,11 @@ def find_patterns_in_file(filepath):
         find_patterns_in_code(f.read())
 
 
-parser = argparse.ArgumentParser(description='Find url patterns.')
+parser = argparse.ArgumentParser(description="Find url patterns.")
 
 parser.add_argument(
     "directories",
-    nargs='+',
+    nargs="+",
     metavar="DIR",
     help="Directories to search for definitions",
 )
diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py
index 7c9985d9f0..44e3a6dbf1 100644
--- a/scripts-dev/tail-synapse.py
+++ b/scripts-dev/tail-synapse.py
@@ -63,5 +63,5 @@ def main():
             streams[update.name] = update.position
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/scripts/generate_signing_key.py b/scripts/generate_signing_key.py
index ba3ba97395..16d7c4f382 100755
--- a/scripts/generate_signing_key.py
+++ b/scripts/generate_signing_key.py
@@ -16,7 +16,7 @@
 import argparse
 import sys
 
-from signedjson.key import write_signing_keys, generate_signing_key
+from signedjson.key import generate_signing_key, write_signing_keys
 
 from synapse.util.stringutils import random_string
 
@@ -24,14 +24,14 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "-o", "--output_file",
-
-        type=argparse.FileType('w'),
+        "-o",
+        "--output_file",
+        type=argparse.FileType("w"),
         default=sys.stdout,
         help="Where to write the output to",
     )
     args = parser.parse_args()
 
     key_id = "a_" + random_string(4)
-    key = generate_signing_key(key_id),
+    key = (generate_signing_key(key_id),)
     write_signing_keys(args.output_file, key)
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index e630936f78..12747c6024 100755
--- a/scripts/move_remote_media_to_new_store.py
+++ b/scripts/move_remote_media_to_new_store.py
@@ -50,7 +50,7 @@ def main(src_repo, dest_repo):
     dest_paths = MediaFilePaths(dest_repo)
     for line in sys.stdin:
         line = line.strip()
-        parts = line.split('|')
+        parts = line.split("|")
         if len(parts) != 2:
             print("Unable to parse input line %s" % line, file=sys.stderr)
             exit(1)
@@ -107,7 +107,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
     )
-    parser.add_argument("-v", action='store_true', help='enable debug logging')
+    parser.add_argument("-v", action="store_true", help="enable debug logging")
     parser.add_argument("src_repo", help="Path to source content repo")
     parser.add_argument("dest_repo", help="Path to source content repo")
     args = parser.parse_args()