summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py200
1 files changed, 59 insertions, 141 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d89f94c219..7474fd515f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017, 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import hashlib
 import logging
 from collections import namedtuple
 
+from six import raise_from
 from six.moves import urllib
 
 from signedjson.key import (
@@ -32,13 +32,16 @@ from signedjson.sign import (
     signature_ids,
     verify_signed_json,
 )
-from unpaddedbase64 import decode_base64, encode_base64
+from unpaddedbase64 import decode_base64
 
-from OpenSSL import crypto
 from twisted.internet import defer
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyclient import fetch_server_key
+from synapse.api.errors import (
+    Codes,
+    HttpResponseException,
+    RequestSendFailed,
+    SynapseError,
+)
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
     LoggingContext,
@@ -47,6 +50,7 @@ from synapse.util.logcontext import (
     run_in_background,
 )
 from synapse.util.metrics import Measure
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -370,13 +374,18 @@ class Keyring(object):
                     server_name_and_key_ids, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
+            except KeyLookupError as e:
+                logger.warning(
+                    "Key lookup failed from %r: %s", perspective_name, e,
+                )
             except Exception as e:
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
                     type(e).__name__, str(e),
                 )
-                defer.returnValue({})
+
+            defer.returnValue({})
 
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
@@ -395,32 +404,13 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_server(self, server_name_and_key_ids):
-        @defer.inlineCallbacks
-        def get_key(server_name, key_ids):
-            keys = None
-            try:
-                keys = yield self.get_server_verify_key_v2_direct(
-                    server_name, key_ids
-                )
-            except Exception as e:
-                logger.info(
-                    "Unable to get key %r for %r directly: %s %s",
-                    key_ids, server_name,
-                    type(e).__name__, str(e),
-                )
-
-            if not keys:
-                keys = yield self.get_server_verify_key_v1_direct(
-                    server_name, key_ids
-                )
-
-                keys = {server_name: keys}
-
-            defer.returnValue(keys)
-
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                run_in_background(get_key, server_name, key_ids)
+                run_in_background(
+                    self.get_server_verify_key_v2_direct,
+                    server_name,
+                    key_ids,
+                )
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
@@ -443,21 +433,30 @@ class Keyring(object):
         # TODO(mark): Set the minimum_valid_until_ts to that needed by
         # the events being validated or the current time if validating
         # an incoming request.
-        query_response = yield self.client.post_json(
-            destination=perspective_name,
-            path="/_matrix/key/v2/query",
-            data={
-                u"server_keys": {
-                    server_name: {
-                        key_id: {
-                            u"minimum_valid_until_ts": 0
-                        } for key_id in key_ids
+        try:
+            query_response = yield self.client.post_json(
+                destination=perspective_name,
+                path="/_matrix/key/v2/query",
+                data={
+                    u"server_keys": {
+                        server_name: {
+                            key_id: {
+                                u"minimum_valid_until_ts": 0
+                            } for key_id in key_ids
+                        }
+                        for server_name, key_ids in server_names_and_key_ids
                     }
-                    for server_name, key_ids in server_names_and_key_ids
-                }
-            },
-            long_retries=True,
-        )
+                },
+                long_retries=True,
+            )
+        except (NotRetryingDestination, RequestSendFailed) as e:
+            raise_from(
+                KeyLookupError("Failed to connect to remote server"), e,
+            )
+        except HttpResponseException as e:
+            raise_from(
+                KeyLookupError("Remote server returned an error"), e,
+            )
 
         keys = {}
 
@@ -524,34 +523,25 @@ class Keyring(object):
             if requested_key_id in keys:
                 continue
 
-            (response, tls_certificate) = yield fetch_server_key(
-                server_name, self.hs.tls_client_options_factory,
-                path=("/_matrix/key/v2/server/%s" % (
-                    urllib.parse.quote(requested_key_id),
-                )).encode("ascii"),
-            )
+            try:
+                response = yield self.client.get_json(
+                    destination=server_name,
+                    path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
+                    ignore_backoff=True,
+                )
+            except (NotRetryingDestination, RequestSendFailed) as e:
+                raise_from(
+                    KeyLookupError("Failed to connect to remote server"), e,
+                )
+            except HttpResponseException as e:
+                raise_from(
+                    KeyLookupError("Remote server returned an error"), e,
+                )
 
             if (u"signatures" not in response
                     or server_name not in response[u"signatures"]):
                 raise KeyLookupError("Key response not signed by remote server")
 
-            if "tls_fingerprints" not in response:
-                raise KeyLookupError("Key response missing TLS fingerprints")
-
-            certificate_bytes = crypto.dump_certificate(
-                crypto.FILETYPE_ASN1, tls_certificate
-            )
-            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
-            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
-
-            response_sha256_fingerprints = set()
-            for fingerprint in response[u"tls_fingerprints"]:
-                if u"sha256" in fingerprint:
-                    response_sha256_fingerprints.add(fingerprint[u"sha256"])
-
-            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
-                raise KeyLookupError("TLS certificate not allowed by fingerprints")
-
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
                 requested_ids=[requested_key_id],
@@ -657,78 +647,6 @@ class Keyring(object):
 
         defer.returnValue(results)
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v1_direct(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Args:
-            server_name (str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
-        """
-
-        # Try to fetch the key from the remote server.
-
-        (response, tls_certificate) = yield fetch_server_key(
-            server_name, self.hs.tls_client_options_factory
-        )
-
-        # Check the response.
-
-        x509_certificate_bytes = crypto.dump_certificate(
-            crypto.FILETYPE_ASN1, tls_certificate
-        )
-
-        if ("signatures" not in response
-                or server_name not in response["signatures"]):
-            raise KeyLookupError("Key response not signed by remote server")
-
-        if "tls_certificate" not in response:
-            raise KeyLookupError("Key response missing TLS certificate")
-
-        tls_certificate_b64 = response["tls_certificate"]
-
-        if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise KeyLookupError("TLS certificate doesn't match")
-
-        # Cache the result in the datastore.
-
-        time_now_ms = self.clock.time_msec()
-
-        verify_keys = {}
-        for key_id, key_base64 in response["verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_key.time_added = time_now_ms
-                verify_keys[key_id] = verify_key
-
-        for key_id in response["signatures"][server_name]:
-            if key_id not in response["verify_keys"]:
-                raise KeyLookupError(
-                    "Key response must include verification keys for all"
-                    " signatures"
-                )
-            if key_id in verify_keys:
-                verify_signed_json(
-                    response,
-                    server_name,
-                    verify_keys[key_id]
-                )
-
-        yield self.store.store_server_certificate(
-            server_name,
-            server_name,
-            time_now_ms,
-            tls_certificate,
-        )
-
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=verify_keys,
-        )
-
-        defer.returnValue(verify_keys)
-
     def store_keys(self, server_name, from_server, verify_keys):
         """Store a collection of verify keys for a given server
         Args:
@@ -768,7 +686,7 @@ def _handle_key_deferred(verify_request):
     try:
         with PreserveLoggingContext():
             _, key_id, verify_key = yield verify_request.deferred
-    except IOError as e:
+    except (IOError, RequestSendFailed) as e:
         logger.warn(
             "Got IOError when downloading keys for %s: %s %s",
             server_name, type(e).__name__, str(e),