summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2019-02-26 14:23:40 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2019-02-26 14:23:40 +0000
commit802884d4ee06ca8e42f46f64e6da7c188d43dc69 (patch)
tree6767e6e142d75e5500092a829d488583fcedef51 /synapse/crypto
parentAdd changelog (diff)
parentMerge pull request #4745 from matrix-org/revert-4736-anoa/public_rooms_federate (diff)
downloadsynapse-802884d4ee06ca8e42f46f64e6da7c188d43dc69.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into anoa/public_rooms_federate_develop
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/context_factory.py39
-rw-r--r--synapse/crypto/event_signing.py109
-rw-r--r--synapse/crypto/keyclient.py147
-rw-r--r--synapse/crypto/keyring.py200
4 files changed, 154 insertions, 341 deletions
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 02b76dfcfb..49cbc7098f 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,12 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from zope.interface import implementer
 
 from OpenSSL import SSL, crypto
 from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet.abstract import isIPAddress, isIPv6Address
 from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
 from twisted.internet.ssl import CertificateOptions, ContextFactory
 from twisted.python.failure import Failure
@@ -42,12 +45,12 @@ class ServerContextFactory(ContextFactory):
             logger.exception("Failed to enable elliptic curve for TLS")
         context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
         context.use_certificate_chain_file(config.tls_certificate_file)
+        context.use_privatekey(config.tls_private_key)
 
-        if not config.no_tls:
-            context.use_privatekey(config.tls_private_key)
-
-        context.load_tmp_dh(config.tls_dh_params_path)
-        context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
+        # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
+        context.set_cipher_list(
+            "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
+        )
 
     def getContext(self):
         return self._context
@@ -96,11 +99,15 @@ class ClientTLSOptions(object):
 
     def __init__(self, hostname, ctx):
         self._ctx = ctx
-        self._hostname = hostname
-        self._hostnameBytes = _idnaBytes(hostname)
-        ctx.set_info_callback(
-            _tolerateErrors(self._identityVerifyingInfoCallback)
-        )
+
+        if isIPAddress(hostname) or isIPv6Address(hostname):
+            self._hostnameBytes = hostname.encode('ascii')
+            self._sendSNI = False
+        else:
+            self._hostnameBytes = _idnaBytes(hostname)
+            self._sendSNI = True
+
+        ctx.set_info_callback(_tolerateErrors(self._identityVerifyingInfoCallback))
 
     def clientConnectionForTLS(self, tlsProtocol):
         context = self._ctx
@@ -109,7 +116,9 @@ class ClientTLSOptions(object):
         return connection
 
     def _identityVerifyingInfoCallback(self, connection, where, ret):
-        if where & SSL.SSL_CB_HANDSHAKE_START:
+        # Literal IPv4 and IPv6 addresses are not permitted
+        # as host names according to the RFCs
+        if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
             connection.set_tlsext_host_name(self._hostnameBytes)
 
 
@@ -119,10 +128,8 @@ class ClientTLSOptionsFactory(object):
 
     def __init__(self, config):
         # We don't use config options yet
-        pass
+        self._options = CertificateOptions(verify=False)
 
     def get_options(self, host):
-        return ClientTLSOptions(
-            host,
-            CertificateOptions(verify=False).getContext()
-        )
+        # Use _makeContext so that we get a fresh OpenSSL CTX each time.
+        return ClientTLSOptions(host, self._options._makeContext())
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 8774b28967..1dfa727fcf 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -23,14 +23,14 @@ from signedjson.sign import sign_json
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.events.utils import prune_event
+from synapse.events.utils import prune_event, prune_event_dict
 
 logger = logging.getLogger(__name__)
 
 
 def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     """Check whether the hash for this PDU matches the contents"""
-    name, expected_hash = compute_content_hash(event, hash_algorithm)
+    name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
     logger.debug("Expecting hash: %s", encode_base64(expected_hash))
 
     # some malformed events lack a 'hashes'. Protect against it being missing
@@ -59,35 +59,70 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     return message_hash_bytes == expected_hash
 
 
-def compute_content_hash(event, hash_algorithm):
-    event_json = event.get_pdu_json()
-    event_json.pop("age_ts", None)
-    event_json.pop("unsigned", None)
-    event_json.pop("signatures", None)
-    event_json.pop("hashes", None)
-    event_json.pop("outlier", None)
-    event_json.pop("destinations", None)
+def compute_content_hash(event_dict, hash_algorithm):
+    """Compute the content hash of an event, which is the hash of the
+    unredacted event.
 
-    event_json_bytes = encode_canonical_json(event_json)
+    Args:
+        event_dict (dict): The unredacted event as a dict
+        hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+            to hash the event
+
+    Returns:
+        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
+        bytes.
+    """
+    event_dict = dict(event_dict)
+    event_dict.pop("age_ts", None)
+    event_dict.pop("unsigned", None)
+    event_dict.pop("signatures", None)
+    event_dict.pop("hashes", None)
+    event_dict.pop("outlier", None)
+    event_dict.pop("destinations", None)
+
+    event_json_bytes = encode_canonical_json(event_dict)
 
     hashed = hash_algorithm(event_json_bytes)
     return (hashed.name, hashed.digest())
 
 
 def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+    """Computes the event reference hash. This is the hash of the redacted
+    event.
+
+    Args:
+        event (FrozenEvent)
+        hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+            to hash the event
+
+    Returns:
+        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
+        bytes.
+    """
     tmp_event = prune_event(event)
-    event_json = tmp_event.get_pdu_json()
-    event_json.pop("signatures", None)
-    event_json.pop("age_ts", None)
-    event_json.pop("unsigned", None)
-    event_json_bytes = encode_canonical_json(event_json)
+    event_dict = tmp_event.get_pdu_json()
+    event_dict.pop("signatures", None)
+    event_dict.pop("age_ts", None)
+    event_dict.pop("unsigned", None)
+    event_json_bytes = encode_canonical_json(event_dict)
     hashed = hash_algorithm(event_json_bytes)
     return (hashed.name, hashed.digest())
 
 
-def compute_event_signature(event, signature_name, signing_key):
-    tmp_event = prune_event(event)
-    redact_json = tmp_event.get_pdu_json()
+def compute_event_signature(event_dict, signature_name, signing_key):
+    """Compute the signature of the event for the given name and key.
+
+    Args:
+        event_dict (dict): The event as a dict
+        signature_name (str): The name of the entity signing the event
+            (typically the server's hostname).
+        signing_key (syutil.crypto.SigningKey): The key to sign with
+
+    Returns:
+        dict[str, dict[str, str]]: Returns a dictionary in the same format of
+        an event's signatures field.
+    """
+    redact_json = prune_event_dict(event_dict)
     redact_json.pop("age_ts", None)
     redact_json.pop("unsigned", None)
     logger.debug("Signing event: %s", encode_canonical_json(redact_json))
@@ -96,25 +131,25 @@ def compute_event_signature(event, signature_name, signing_key):
     return redact_json["signatures"]
 
 
-def add_hashes_and_signatures(event, signature_name, signing_key,
+def add_hashes_and_signatures(event_dict, signature_name, signing_key,
                               hash_algorithm=hashlib.sha256):
-    # if hasattr(event, "old_state_events"):
-    #     state_json_bytes = encode_canonical_json(
-    #         [e.event_id for e in event.old_state_events.values()]
-    #     )
-    #     hashed = hash_algorithm(state_json_bytes)
-    #     event.state_hash = {
-    #         hashed.name: encode_base64(hashed.digest())
-    #     }
-
-    name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
-
-    if not hasattr(event, "hashes"):
-        event.hashes = {}
-    event.hashes[name] = encode_base64(digest)
-
-    event.signatures = compute_event_signature(
-        event,
+    """Add content hash and sign the event
+
+    Args:
+        event_dict (dict): The event to add hashes to and sign
+        signature_name (str): The name of the entity signing the event
+            (typically the server's hostname).
+        signing_key (syutil.crypto.SigningKey): The key to sign with
+        hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
+            to hash the event
+    """
+
+    name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
+
+    event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
+
+    event_dict["signatures"] = compute_event_signature(
+        event_dict,
         signature_name=signature_name,
         signing_key=signing_key,
     )
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
deleted file mode 100644
index 080c81f14b..0000000000
--- a/synapse/crypto/keyclient.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.internet.protocol import Factory
-from twisted.names.error import DomainError
-from twisted.web.http import HTTPClient
-
-from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util import logcontext
-
-logger = logging.getLogger(__name__)
-
-KEY_API_V1 = b"/_matrix/key/v1/"
-
-
-@defer.inlineCallbacks
-def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
-    """Fetch the keys for a remote server."""
-
-    factory = SynapseKeyClientFactory()
-    factory.path = path
-    factory.host = server_name
-    endpoint = matrix_federation_endpoint(
-        reactor, server_name, tls_client_options_factory, timeout=30
-    )
-
-    for i in range(5):
-        try:
-            with logcontext.PreserveLoggingContext():
-                protocol = yield endpoint.connect(factory)
-                server_response, server_certificate = yield protocol.remote_key
-                defer.returnValue((server_response, server_certificate))
-        except SynapseKeyClientError as e:
-            logger.warn("Error getting key for %r: %s", server_name, e)
-            if e.status.startswith(b"4"):
-                # Don't retry for 4xx responses.
-                raise IOError("Cannot get key for %r" % server_name)
-        except (ConnectError, DomainError) as e:
-            logger.warn("Error getting key for %r: %s", server_name, e)
-        except Exception:
-            logger.exception("Error getting key for %r", server_name)
-    raise IOError("Cannot get key for %r" % server_name)
-
-
-class SynapseKeyClientError(Exception):
-    """The key wasn't retrieved from the remote server."""
-    status = None
-    pass
-
-
-class SynapseKeyClientProtocol(HTTPClient):
-    """Low level HTTPS client which retrieves an application/json response from
-    the server and extracts the X.509 certificate for the remote peer from the
-    SSL connection."""
-
-    timeout = 30
-
-    def __init__(self):
-        self.remote_key = defer.Deferred()
-        self.host = None
-        self._peer = None
-
-    def connectionMade(self):
-        self._peer = self.transport.getPeer()
-        logger.debug("Connected to %s", self._peer)
-
-        if not isinstance(self.path, bytes):
-            self.path = self.path.encode('ascii')
-
-        if not isinstance(self.host, bytes):
-            self.host = self.host.encode('ascii')
-
-        self.sendCommand(b"GET", self.path)
-        if self.host:
-            self.sendHeader(b"Host", self.host)
-        self.endHeaders()
-        self.timer = reactor.callLater(
-            self.timeout,
-            self.on_timeout
-        )
-
-    def errback(self, error):
-        if not self.remote_key.called:
-            self.remote_key.errback(error)
-
-    def callback(self, result):
-        if not self.remote_key.called:
-            self.remote_key.callback(result)
-
-    def handleStatus(self, version, status, message):
-        if status != b"200":
-            # logger.info("Non-200 response from %s: %s %s",
-            #            self.transport.getHost(), status, message)
-            error = SynapseKeyClientError(
-                "Non-200 response %r from %r" % (status, self.host)
-            )
-            error.status = status
-            self.errback(error)
-            self.transport.abortConnection()
-
-    def handleResponse(self, response_body_bytes):
-        try:
-            json_response = json.loads(response_body_bytes)
-        except ValueError:
-            # logger.info("Invalid JSON response from %s",
-            #            self.transport.getHost())
-            self.transport.abortConnection()
-            return
-
-        certificate = self.transport.getPeerCertificate()
-        self.callback((json_response, certificate))
-        self.transport.abortConnection()
-        self.timer.cancel()
-
-    def on_timeout(self):
-        logger.debug(
-            "Timeout waiting for response from %s: %s",
-            self.host, self._peer,
-        )
-        self.errback(IOError("Timeout waiting for response"))
-        self.transport.abortConnection()
-
-
-class SynapseKeyClientFactory(Factory):
-    def protocol(self):
-        protocol = SynapseKeyClientProtocol()
-        protocol.path = self.path
-        protocol.host = self.host
-        return protocol
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d89f94c219..7474fd515f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017, 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import hashlib
 import logging
 from collections import namedtuple
 
+from six import raise_from
 from six.moves import urllib
 
 from signedjson.key import (
@@ -32,13 +32,16 @@ from signedjson.sign import (
     signature_ids,
     verify_signed_json,
 )
-from unpaddedbase64 import decode_base64, encode_base64
+from unpaddedbase64 import decode_base64
 
-from OpenSSL import crypto
 from twisted.internet import defer
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyclient import fetch_server_key
+from synapse.api.errors import (
+    Codes,
+    HttpResponseException,
+    RequestSendFailed,
+    SynapseError,
+)
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
     LoggingContext,
@@ -47,6 +50,7 @@ from synapse.util.logcontext import (
     run_in_background,
 )
 from synapse.util.metrics import Measure
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -370,13 +374,18 @@ class Keyring(object):
                     server_name_and_key_ids, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
+            except KeyLookupError as e:
+                logger.warning(
+                    "Key lookup failed from %r: %s", perspective_name, e,
+                )
             except Exception as e:
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
                     type(e).__name__, str(e),
                 )
-                defer.returnValue({})
+
+            defer.returnValue({})
 
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
@@ -395,32 +404,13 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_server(self, server_name_and_key_ids):
-        @defer.inlineCallbacks
-        def get_key(server_name, key_ids):
-            keys = None
-            try:
-                keys = yield self.get_server_verify_key_v2_direct(
-                    server_name, key_ids
-                )
-            except Exception as e:
-                logger.info(
-                    "Unable to get key %r for %r directly: %s %s",
-                    key_ids, server_name,
-                    type(e).__name__, str(e),
-                )
-
-            if not keys:
-                keys = yield self.get_server_verify_key_v1_direct(
-                    server_name, key_ids
-                )
-
-                keys = {server_name: keys}
-
-            defer.returnValue(keys)
-
         results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                run_in_background(get_key, server_name, key_ids)
+                run_in_background(
+                    self.get_server_verify_key_v2_direct,
+                    server_name,
+                    key_ids,
+                )
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
@@ -443,21 +433,30 @@ class Keyring(object):
         # TODO(mark): Set the minimum_valid_until_ts to that needed by
         # the events being validated or the current time if validating
         # an incoming request.
-        query_response = yield self.client.post_json(
-            destination=perspective_name,
-            path="/_matrix/key/v2/query",
-            data={
-                u"server_keys": {
-                    server_name: {
-                        key_id: {
-                            u"minimum_valid_until_ts": 0
-                        } for key_id in key_ids
+        try:
+            query_response = yield self.client.post_json(
+                destination=perspective_name,
+                path="/_matrix/key/v2/query",
+                data={
+                    u"server_keys": {
+                        server_name: {
+                            key_id: {
+                                u"minimum_valid_until_ts": 0
+                            } for key_id in key_ids
+                        }
+                        for server_name, key_ids in server_names_and_key_ids
                     }
-                    for server_name, key_ids in server_names_and_key_ids
-                }
-            },
-            long_retries=True,
-        )
+                },
+                long_retries=True,
+            )
+        except (NotRetryingDestination, RequestSendFailed) as e:
+            raise_from(
+                KeyLookupError("Failed to connect to remote server"), e,
+            )
+        except HttpResponseException as e:
+            raise_from(
+                KeyLookupError("Remote server returned an error"), e,
+            )
 
         keys = {}
 
@@ -524,34 +523,25 @@ class Keyring(object):
             if requested_key_id in keys:
                 continue
 
-            (response, tls_certificate) = yield fetch_server_key(
-                server_name, self.hs.tls_client_options_factory,
-                path=("/_matrix/key/v2/server/%s" % (
-                    urllib.parse.quote(requested_key_id),
-                )).encode("ascii"),
-            )
+            try:
+                response = yield self.client.get_json(
+                    destination=server_name,
+                    path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
+                    ignore_backoff=True,
+                )
+            except (NotRetryingDestination, RequestSendFailed) as e:
+                raise_from(
+                    KeyLookupError("Failed to connect to remote server"), e,
+                )
+            except HttpResponseException as e:
+                raise_from(
+                    KeyLookupError("Remote server returned an error"), e,
+                )
 
             if (u"signatures" not in response
                     or server_name not in response[u"signatures"]):
                 raise KeyLookupError("Key response not signed by remote server")
 
-            if "tls_fingerprints" not in response:
-                raise KeyLookupError("Key response missing TLS fingerprints")
-
-            certificate_bytes = crypto.dump_certificate(
-                crypto.FILETYPE_ASN1, tls_certificate
-            )
-            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
-            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
-
-            response_sha256_fingerprints = set()
-            for fingerprint in response[u"tls_fingerprints"]:
-                if u"sha256" in fingerprint:
-                    response_sha256_fingerprints.add(fingerprint[u"sha256"])
-
-            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
-                raise KeyLookupError("TLS certificate not allowed by fingerprints")
-
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
                 requested_ids=[requested_key_id],
@@ -657,78 +647,6 @@ class Keyring(object):
 
         defer.returnValue(results)
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v1_direct(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Args:
-            server_name (str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
-        """
-
-        # Try to fetch the key from the remote server.
-
-        (response, tls_certificate) = yield fetch_server_key(
-            server_name, self.hs.tls_client_options_factory
-        )
-
-        # Check the response.
-
-        x509_certificate_bytes = crypto.dump_certificate(
-            crypto.FILETYPE_ASN1, tls_certificate
-        )
-
-        if ("signatures" not in response
-                or server_name not in response["signatures"]):
-            raise KeyLookupError("Key response not signed by remote server")
-
-        if "tls_certificate" not in response:
-            raise KeyLookupError("Key response missing TLS certificate")
-
-        tls_certificate_b64 = response["tls_certificate"]
-
-        if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise KeyLookupError("TLS certificate doesn't match")
-
-        # Cache the result in the datastore.
-
-        time_now_ms = self.clock.time_msec()
-
-        verify_keys = {}
-        for key_id, key_base64 in response["verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_key.time_added = time_now_ms
-                verify_keys[key_id] = verify_key
-
-        for key_id in response["signatures"][server_name]:
-            if key_id not in response["verify_keys"]:
-                raise KeyLookupError(
-                    "Key response must include verification keys for all"
-                    " signatures"
-                )
-            if key_id in verify_keys:
-                verify_signed_json(
-                    response,
-                    server_name,
-                    verify_keys[key_id]
-                )
-
-        yield self.store.store_server_certificate(
-            server_name,
-            server_name,
-            time_now_ms,
-            tls_certificate,
-        )
-
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=verify_keys,
-        )
-
-        defer.returnValue(verify_keys)
-
     def store_keys(self, server_name, from_server, verify_keys):
         """Store a collection of verify keys for a given server
         Args:
@@ -768,7 +686,7 @@ def _handle_key_deferred(verify_request):
     try:
         with PreserveLoggingContext():
             _, key_id, verify_key = yield verify_request.deferred
-    except IOError as e:
+    except (IOError, RequestSendFailed) as e:
         logger.warn(
             "Got IOError when downloading keys for %s: %s %s",
             server_name, type(e).__name__, str(e),