summary refs log tree commit diff
path: root/scripts-dev/federation_client.py
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-04-27 14:10:31 +0100
committerGitHub <noreply@github.com>2022-04-27 13:10:31 +0000
commit30c8e7e408322967e5beb2a64ef5f796cb8df226 (patch)
tree1d65d45fd7ddf5735c80282c62ae1cc5aae2b708 /scripts-dev/federation_client.py
parentRemove unused `# type: ignore`s (#12531) (diff)
downloadsynapse-30c8e7e408322967e5beb2a64ef5f796cb8df226.tar.xz
Make `scripts-dev` pass `mypy --disallow-untyped-defs` (#12356)
Not enforced in config yet. One day.
Diffstat (limited to 'scripts-dev/federation_client.py')
-rwxr-xr-xscripts-dev/federation_client.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 079d2f5ed0..763dd02c47 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -38,7 +38,7 @@ import argparse
 import base64
 import json
 import sys
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Tuple
 from urllib import parse as urlparse
 
 import requests
@@ -47,13 +47,14 @@ import signedjson.types
 import srvlookup
 import yaml
 from requests.adapters import HTTPAdapter
+from urllib3 import HTTPConnectionPool
 
 # uncomment the following to enable debug logging of http requests
 # from httplib import HTTPConnection
 # HTTPConnection.debuglevel = 1
 
 
-def encode_base64(input_bytes):
+def encode_base64(input_bytes: bytes) -> str:
     """Encode bytes as a base64 string without any padding."""
 
     input_len = len(input_bytes)
@@ -63,7 +64,7 @@ def encode_base64(input_bytes):
     return output_string
 
 
-def encode_canonical_json(value):
+def encode_canonical_json(value: object) -> bytes:
     return json.dumps(
         value,
         # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
@@ -130,7 +131,7 @@ def request(
             sig,
             destination,
         )
-        authorization_headers.append(header.encode("ascii"))
+        authorization_headers.append(header)
         print("Authorization: %s" % header, file=sys.stderr)
 
     dest = "matrix://%s%s" % (destination, path)
@@ -139,7 +140,10 @@ def request(
     s = requests.Session()
     s.mount("matrix://", MatrixConnectionAdapter())
 
-    headers = {"Host": destination, "Authorization": authorization_headers[0]}
+    headers: Dict[str, str] = {
+        "Host": destination,
+        "Authorization": authorization_headers[0],
+    }
 
     if method == "POST":
         headers["Content-Type"] = "application/json"
@@ -154,7 +158,7 @@ def request(
     )
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="Signs and sends a federation request to a matrix homeserver"
     )
@@ -212,6 +216,7 @@ def main():
     if not args.server_name or not args.signing_key:
         read_args_from_config(args)
 
+    assert isinstance(args.signing_key, str)
     algorithm, version, key_base64 = args.signing_key.split()
     key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
 
@@ -233,7 +238,7 @@ def main():
     print("")
 
 
-def read_args_from_config(args):
+def read_args_from_config(args: argparse.Namespace) -> None:
     with open(args.config, "r") as fh:
         config = yaml.safe_load(fh)
 
@@ -250,7 +255,7 @@ def read_args_from_config(args):
 
 class MatrixConnectionAdapter(HTTPAdapter):
     @staticmethod
-    def lookup(s, skip_well_known=False):
+    def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
         if s[-1] == "]":
             # ipv6 literal (with no port)
             return s, 8448
@@ -276,7 +281,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
             return s, 8448
 
     @staticmethod
-    def get_well_known(server_name):
+    def get_well_known(server_name: str) -> Optional[str]:
         uri = "https://%s/.well-known/matrix/server" % (server_name,)
         print("fetching %s" % (uri,), file=sys.stderr)
 
@@ -299,7 +304,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
             print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
         return None
 
-    def get_connection(self, url, proxies=None):
+    def get_connection(
+        self, url: str, proxies: Optional[Dict[str, str]] = None
+    ) -> HTTPConnectionPool:
         parsed = urlparse.urlparse(url)
 
         (host, port) = self.lookup(parsed.netloc)