summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-21 09:07:07 -0400
committerGitHub <noreply@github.com>2021-10-21 13:07:07 +0000
commit0f9adc99ada1f66f4897c8164dcf509a955e5584 (patch)
treeefc2c73d0f303d86ccd2aa97e3817891887e2e79
parentfix relative link in docker readme (#11144) (diff)
downloadsynapse-0f9adc99ada1f66f4897c8164dcf509a955e5584.tar.xz
Add missing type hints to synapse.crypto. (#11146)
And require type hints for this module.
-rw-r--r--changelog.d/11146.misc1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/crypto/context_factory.py40
-rw-r--r--synapse/crypto/event_signing.py2
-rw-r--r--synapse/crypto/keyring.py8
5 files changed, 36 insertions, 18 deletions
diff --git a/changelog.d/11146.misc b/changelog.d/11146.misc
new file mode 100644
index 0000000000..6ce1c9f9f5
--- /dev/null
+++ b/changelog.d/11146.misc
@@ -0,0 +1 @@
+Add missing type hints to `synapse.crypto`.
diff --git a/mypy.ini b/mypy.ini
index 14d8bb8eaf..c5f44aea39 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -103,6 +103,9 @@ files =
 [mypy-synapse.api.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.crypto.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.events.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2a6110eb10..7855f3498b 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -29,9 +29,12 @@ from twisted.internet.ssl import (
     TLSVersion,
     platformTrust,
 )
+from twisted.protocols.tls import TLSMemoryBIOProtocol
 from twisted.python.failure import Failure
 from twisted.web.iweb import IPolicyForHTTPS
 
+from synapse.config.homeserver import HomeServerConfig
+
 logger = logging.getLogger(__name__)
 
 
@@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory):
     per https://github.com/matrix-org/synapse/issues/1691
     """
 
-    def __init__(self, config):
+    def __init__(self, config: HomeServerConfig):
         # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
         # switch to those (see https://github.com/pyca/cryptography/issues/5379).
         #
@@ -64,7 +67,7 @@ class ServerContextFactory(ContextFactory):
         self.configure_context(self._context, config)
 
     @staticmethod
-    def configure_context(context, config):
+    def configure_context(context: SSL.Context, config: HomeServerConfig) -> None:
         try:
             _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
             context.set_tmp_ecdh(_ecCurve)
@@ -75,14 +78,15 @@ class ServerContextFactory(ContextFactory):
             SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
         )
         context.use_certificate_chain_file(config.tls.tls_certificate_file)
+        assert config.tls.tls_private_key is not None
         context.use_privatekey(config.tls.tls_private_key)
 
         # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
         context.set_cipher_list(
-            "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
+            b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
         )
 
-    def getContext(self):
+    def getContext(self) -> SSL.Context:
         return self._context
 
 
@@ -98,7 +102,7 @@ class FederationPolicyForHTTPS:
     constructs an SSLClientConnectionCreator factory accordingly.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: HomeServerConfig):
         self._config = config
 
         # Check if we're using a custom list of a CA certificates
@@ -131,7 +135,7 @@ class FederationPolicyForHTTPS:
             self._config.tls.federation_certificate_verification_whitelist
         )
 
-    def get_options(self, host: bytes):
+    def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator:
         # IPolicyForHTTPS.get_options takes bytes, but we want to compare
         # against the str whitelist. The hostnames in the whitelist are already
         # IDNA-encoded like the hosts will be here.
@@ -153,7 +157,9 @@ class FederationPolicyForHTTPS:
 
         return SSLClientConnectionCreator(host, ssl_context, should_verify)
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(
+        self, hostname: bytes, port: int
+    ) -> IOpenSSLClientConnectionCreator:
         """Implements the IPolicyForHTTPS interface so that this can be passed
         directly to agents.
         """
@@ -169,16 +175,18 @@ class RegularPolicyForHTTPS:
     trust root.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         trust_root = platformTrust()
         self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
         self._ssl_context.set_info_callback(_context_info_cb)
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(
+        self, hostname: bytes, port: int
+    ) -> IOpenSSLClientConnectionCreator:
         return SSLClientConnectionCreator(hostname, self._ssl_context, True)
 
 
-def _context_info_cb(ssl_connection, where, ret):
+def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None:
     """The 'information callback' for our openssl context objects.
 
     Note: Once this is set as the info callback on a Context object, the Context should
@@ -204,11 +212,13 @@ class SSLClientConnectionCreator:
     Replaces twisted.internet.ssl.ClientTLSOptions
     """
 
-    def __init__(self, hostname: bytes, ctx, verify_certs: bool):
+    def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool):
         self._ctx = ctx
         self._verifier = ConnectionVerifier(hostname, verify_certs)
 
-    def clientConnectionForTLS(self, tls_protocol):
+    def clientConnectionForTLS(
+        self, tls_protocol: TLSMemoryBIOProtocol
+    ) -> SSL.Connection:
         context = self._ctx
         connection = SSL.Connection(context, None)
 
@@ -219,7 +229,7 @@ class SSLClientConnectionCreator:
         # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
         # tls_protocol so that the SSL context's info callback has something to
         # call to do the cert verification.
-        tls_protocol._synapse_tls_verifier = self._verifier
+        tls_protocol._synapse_tls_verifier = self._verifier  # type: ignore[attr-defined]
         return connection
 
 
@@ -244,7 +254,9 @@ class ConnectionVerifier:
         self._hostnameBytes = hostname
         self._hostnameASCII = self._hostnameBytes.decode("ascii")
 
-    def verify_context_info_cb(self, ssl_connection, where):
+    def verify_context_info_cb(
+        self, ssl_connection: SSL.Connection, where: int
+    ) -> None:
         if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
             ssl_connection.set_tlsext_host_name(self._hostnameBytes)
 
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0f2b632e47..7520647d1e 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -100,7 +100,7 @@ def compute_content_hash(
 
 
 def compute_event_reference_hash(
-    event, hash_algorithm: Hasher = hashlib.sha256
+    event: EventBase, hash_algorithm: Hasher = hashlib.sha256
 ) -> Tuple[str, bytes]:
     """Computes the event reference hash. This is the hash of the redacted
     event.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e1e13a2412..8628e951c4 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -87,7 +87,7 @@ class VerifyJsonRequest:
         server_name: str,
         json_object: JsonDict,
         minimum_valid_until_ms: int,
-    ):
+    ) -> "VerifyJsonRequest":
         """Create a VerifyJsonRequest to verify all signatures on a signed JSON
         object for the given server.
         """
@@ -104,7 +104,7 @@ class VerifyJsonRequest:
         server_name: str,
         event: EventBase,
         minimum_valid_until_ms: int,
-    ):
+    ) -> "VerifyJsonRequest":
         """Create a VerifyJsonRequest to verify all signatures on an event
         object for the given server.
         """
@@ -449,7 +449,9 @@ class StoreKeyFetcher(KeyFetcher):
 
         self.store = hs.get_datastore()
 
-    async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_FetchKeyRequest]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         key_ids_to_fetch = (
             (queue_value.server_name, key_id)
             for queue_value in keys_to_fetch