summary refs log tree commit diff
path: root/scripts-dev/federation_client.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-18 21:12:07 +0000
committerGitHub <noreply@github.com>2021-03-18 21:12:07 +0000
commit0e355847347d6a2334dd8879174b1b6806b2a38b (patch)
tree9707a7ab5ee741723344272506ddce2f7d3ab8a6 /scripts-dev/federation_client.py
parentfederation_client: stop adding URL prefix (#9645) (diff)
downloadsynapse-0e355847347d6a2334dd8879174b1b6806b2a38b.tar.xz
federation_client: handle inline signing_keys in hs.yaml (#9647)
Diffstat (limited to '')
-rwxr-xr-xscripts-dev/federation_client.py71
1 files changed, 17 insertions, 54 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 657919a0d7..6f76c08fcf 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -22,8 +22,8 @@ import sys
 from typing import Any, Optional
 from urllib import parse as urlparse
 
-import nacl.signing
 import requests
+import signedjson.key
 import signedjson.types
 import srvlookup
 import yaml
@@ -44,18 +44,6 @@ def encode_base64(input_bytes):
     return output_string
 
 
-def decode_base64(input_string):
-    """Decode a base64 string to bytes inferring padding from the length of the
-    string."""
-
-    input_bytes = input_string.encode("ascii")
-    input_len = len(input_bytes)
-    padding = b"=" * (3 - ((input_len + 3) % 4))
-    output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
-    output_bytes = base64.b64decode(input_bytes + padding)
-    return output_bytes[:output_len]
-
-
 def encode_canonical_json(value):
     return json.dumps(
         value,
@@ -88,42 +76,6 @@ def sign_json(
     return json_object
 
 
-NACL_ED25519 = "ed25519"
-
-
-def decode_signing_key_base64(algorithm, version, key_base64):
-    """Decode a base64 encoded signing key
-    Args:
-        algorithm (str): The algorithm the key is for (currently "ed25519").
-        version (str): Identifies this key out of the keys for this entity.
-        key_base64 (str): Base64 encoded bytes of the key.
-    Returns:
-        A SigningKey object.
-    """
-    if algorithm == NACL_ED25519:
-        key_bytes = decode_base64(key_base64)
-        key = nacl.signing.SigningKey(key_bytes)
-        key.version = version
-        key.alg = NACL_ED25519
-        return key
-    else:
-        raise ValueError("Unsupported algorithm %s" % (algorithm,))
-
-
-def read_signing_keys(stream):
-    """Reads a list of keys from a stream
-    Args:
-        stream : A stream to iterate for keys.
-    Returns:
-        list of SigningKey objects.
-    """
-    keys = []
-    for line in stream:
-        algorithm, version, key_base64 = line.split()
-        keys.append(decode_signing_key_base64(algorithm, version, key_base64))
-    return keys
-
-
 def request(
     method: Optional[str],
     origin_name: str,
@@ -228,11 +180,16 @@ def main():
 
     args = parser.parse_args()
 
-    if not args.server_name or not args.signing_key_path:
+    args.signing_key = None
+    if args.signing_key_path:
+        with open(args.signing_key_path) as f:
+            args.signing_key = f.readline()
+
+    if not args.server_name or not args.signing_key:
         read_args_from_config(args)
 
-    with open(args.signing_key_path) as f:
-        key = read_signing_keys(f)[0]
+    algorithm, version, key_base64 = args.signing_key.split()
+    key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
 
     result = request(
         args.method,
@@ -255,10 +212,16 @@ def main():
 def read_args_from_config(args):
     with open(args.config, "r") as fh:
         config = yaml.safe_load(fh)
+
         if not args.server_name:
             args.server_name = config["server_name"]
-        if not args.signing_key_path:
-            args.signing_key_path = config["signing_key_path"]
+
+        if not args.signing_key:
+            if "signing_key" in config:
+                args.signing_key = config["signing_key"]
+            else:
+                with open(config["signing_key_path"]) as f:
+                    args.signing_key = f.readline()
 
 
 class MatrixConnectionAdapter(HTTPAdapter):