summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-06-26 09:52:24 +0100
committerErik Johnston <erik@matrix.org>2015-06-26 10:39:34 +0100
commitb5f55a1d85386834b18828b196e1859a4410d98a (patch)
tree6c8b39f1a6df244b7cabcc503bc300f0dd686754 /synapse
parentBatch SELECTs in _get_auth_chain_ids_txn (diff)
downloadsynapse-b5f55a1d85386834b18828b196e1859a4410d98a.tar.xz
Implement bulk verify_signed_json API
Diffstat (limited to '')
-rw-r--r--synapse/crypto/keyring.py426
-rw-r--r--synapse/federation/federation_base.py125
-rw-r--r--synapse/federation/federation_client.py57
-rw-r--r--synapse/storage/keys.py50
4 files changed, 441 insertions, 217 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..873c9b40fa 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
 from synapse.api.errors import SynapseError, Codes
 
 from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
 
 from OpenSSL import crypto
 
+from collections import namedtuple
 import urllib
 import hashlib
 import logging
@@ -38,6 +38,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -49,141 +52,274 @@ class Keyring(object):
 
         self.key_downloads = {}
 
-    @defer.inlineCallbacks
     def verify_json_for_server(self, server_name, json_object):
-        logger.debug("Verifying for %s", server_name)
-        key_ids = signature_ids(json_object, server_name)
-        if not key_ids:
-            raise SynapseError(
-                400,
-                "Not signed with a supported algorithm",
-                Codes.UNAUTHORIZED,
-            )
-        try:
-            verify_key = yield self.get_server_verify_key(server_name, key_ids)
-        except IOError as e:
-            logger.warn(
-                "Got IOError when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                502,
-                "Error downloading keys for %s" % (server_name,),
-                Codes.UNAUTHORIZED,
-            )
-        except Exception as e:
-            logger.warn(
-                "Got Exception when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                401,
-                "No key for %s with id %s" % (server_name, key_ids),
-                Codes.UNAUTHORIZED,
-            )
+        return self.verify_json_objects_for_server(
+            [(server_name, json_object)]
+        )[0]
 
-        try:
-            verify_signed_json(json_object, server_name, verify_key)
-        except:
-            raise SynapseError(
-                401,
-                "Invalid signature for server %s with key %s:%s" % (
-                    server_name, verify_key.alg, verify_key.version
-                ),
-                Codes.UNAUTHORIZED,
-            )
+    def verify_json_objects_for_server(self, server_and_json):
+        """Bulk verfies signatures of json objects, bulk fetching keys as
+        necessary.
 
-    @defer.inlineCallbacks
-    def get_server_verify_key(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Trys to fetch the key from a trusted perspective server first.
         Args:
-            server_name(str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
+            server_and_json (list): List of pairs of (server_name, json_object)
+
+        Returns:
+            list of deferreds indicating success or failure to verify each
+            json object's signature for the given server_name.
         """
-        cached = yield self.store.get_server_verify_keys(server_name, key_ids)
+        group_id_to_json = {}
+        group_id_to_group = {}
+        group_ids = []
+
+        next_group_id = 0
+        deferreds = {}
+
+        for server_name, json_object in server_and_json:
+            logger.debug("Verifying for %s", server_name)
+            group_id = next_group_id
+            next_group_id += 1
+            group_ids.append(group_id)
+
+            key_ids = signature_ids(json_object, server_name)
+            if not key_ids:
+                deferreds[group_id] = defer.fail(SynapseError(
+                    400,
+                    "Not signed with a supported algorithm",
+                    Codes.UNAUTHORIZED,
+                ))
+
+            group = KeyGroup(server_name, group_id, key_ids)
 
-        if cached:
-            defer.returnValue(cached[0])
-            return
+            group_id_to_group[group_id] = group
+            group_id_to_json[group_id] = json_object
 
-        download = self.key_downloads.get(server_name)
+        @defer.inlineCallbacks
+        def handle_key_deferred(group, deferred):
+            server_name = group.server_name
+            try:
+                _, _, key_id, verify_key = yield deferred
+            except IOError as e:
+                logger.warn(
+                    "Got IOError when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    502,
+                    "Error downloading keys for %s" % (server_name,),
+                    Codes.UNAUTHORIZED,
+                )
+            except Exception as e:
+                logger.exception(
+                    "Got Exception when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    401,
+                    "No key for %s with id %s" % (server_name, key_ids),
+                    Codes.UNAUTHORIZED,
+                )
+
+            json_object = group_id_to_json[group.group_id]
 
-        if download is None:
-            download = self._get_server_verify_key_impl(server_name, key_ids)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
+            try:
+                verify_signed_json(json_object, server_name, verify_key)
+            except:
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s" % (
+                        server_name, verify_key.alg, verify_key.version
+                    ),
+                    Codes.UNAUTHORIZED,
+                )
+
+        deferreds.update(self.get_server_verify_keys(
+            group_id_to_group
+        ))
+
+        return [
+            handle_key_deferred(
+                group_id_to_group[g_id],
+                deferreds[g_id],
             )
-            self.key_downloads[server_name] = download
+            for g_id in group_ids
+        ]
 
-            @download.addBoth
-            def callback(ret):
-                del self.key_downloads[server_name]
-                return ret
+    def get_server_verify_keys(self, group_id_to_group):
+        """Takes a dict of KeyGroups and tries to find at least one key for
+        each group.
+        """
+
+        # These are functions that produce keys given a list of key ids
+        key_fetch_fns = (
+            self.get_keys_from_store,  # First try the local store
+            self.get_keys_from_perspectives,  # Then try via perspectives
+            self.get_keys_from_server,  # Then try directly
+        )
+
+        group_deferreds = {
+            group_id: defer.Deferred()
+            for group_id in group_id_to_group
+        }
+
+        @defer.inlineCallbacks
+        def do_iterations():
+            merged_results = {}
+
+            missing_keys = {
+                group.server_name: key_id
+                for group in group_id_to_group.values()
+                for key_id in group.key_ids
+            }
+
+            for fn in key_fetch_fns:
+                results = yield fn(missing_keys.items())
+                merged_results.update(results)
+
+                # We now need to figure out which groups we have keys for
+                # and which we don't
+                missing_groups = {}
+                for group in group_id_to_group.values():
+                    for key_id in group.key_ids:
+                        if key_id in merged_results[group.server_name]:
+                            group_deferreds.pop(group.group_id).callback((
+                                group.group_id,
+                                group.server_name,
+                                key_id,
+                                merged_results[group.server_name][key_id],
+                            ))
+                            break
+                    else:
+                        missing_groups.setdefault(
+                            group.server_name, []
+                        ).append(group)
+
+                if not missing_groups:
+                    break
+
+                missing_keys = {
+                    server_name: set(
+                        key_id for group in groups for key_id in group.key_ids
+                    )
+                    for server_name, groups in missing_groups.items()
+                }
 
-        r = yield download.observe()
-        defer.returnValue(r)
+            for group in missing_groups.values():
+                group_deferreds.pop(group.group_id).errback(SynapseError(
+                    401,
+                    "No key for %s with id %s" % (
+                        group.server_name, group.key_ids,
+                    ),
+                    Codes.UNAUTHORIZED,
+                ))
+
+        def on_err(err):
+            for deferred in group_deferreds.values():
+                deferred.errback(err)
+            group_deferreds.clear()
+
+        do_iterations().addErrback(on_err)
+
+        return group_deferreds
 
     @defer.inlineCallbacks
-    def _get_server_verify_key_impl(self, server_name, key_ids):
-        keys = None
+    def get_keys_from_store(self, server_name_and_key_ids):
+        res = yield defer.gatherResults(
+            [
+                self.store.get_server_verify_keys(server_name, key_ids)
+                for server_name, key_ids in server_name_and_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        defer.returnValue(dict(zip(
+            [server_name for server_name, _ in server_name_and_key_ids],
+            res
+        )))
 
+    @defer.inlineCallbacks
+    def get_keys_from_perspectives(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name, key_ids, perspective_name, perspective_keys
+                    server_name_and_key_ids, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except Exception as e:
-                logging.info(
-                    "Unable to getting key %r for %r from %r: %s %s",
-                    key_ids, server_name, perspective_name,
+                logger.exception(
+                    "Unable to get key from %r: %s %s",
+                    perspective_name,
                     type(e).__name__, str(e.message),
                 )
+                defer.returnValue({})
 
-        perspective_results = yield defer.gatherResults([
-            get_key(p_name, p_keys)
-            for p_name, p_keys in self.perspective_servers.items()
-        ])
+        results = yield defer.gatherResults(
+            [
+                get_key(p_name, p_keys)
+                for p_name, p_keys in self.perspective_servers.items()
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        for results in perspective_results:
-            if results is not None:
-                keys = results
+        union_of_keys = {}
+        for result in results:
+            for server_name, keys in result.items():
+                union_of_keys.setdefault(server_name, {}).update(keys)
 
-        limiter = yield get_retry_limiter(
-            server_name,
-            self.clock,
-            self.store,
-        )
+        defer.returnValue(union_of_keys)
 
-        with limiter:
-            if not keys:
+    @defer.inlineCallbacks
+    def get_keys_from_server(self, server_name_and_key_ids):
+        @defer.inlineCallbacks
+        def get_key(server_name, key_ids):
+            limiter = yield get_retry_limiter(
+                server_name,
+                self.clock,
+                self.store,
+            )
+            with limiter:
+                keys = None
                 try:
                     keys = yield self.get_server_verify_key_v2_direct(
                         server_name, key_ids
                     )
                 except Exception as e:
-                    logging.info(
+                    logger.info(
                         "Unable to getting key %r for %r directly: %s %s",
                         key_ids, server_name,
                         type(e).__name__, str(e.message),
                     )
 
-            if not keys:
-                keys = yield self.get_server_verify_key_v1_direct(
-                    server_name, key_ids
-                )
+                if not keys:
+                    keys = yield self.get_server_verify_key_v1_direct(
+                        server_name, key_ids
+                    )
+
+                    keys = {server_name: keys}
+
+            defer.returnValue(keys)
+
+        results = yield defer.gatherResults(
+            [
+                get_key(server_name, key_ids)
+                for server_name, key_ids in server_name_and_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        for key_id in key_ids:
-            if key_id in keys:
-                defer.returnValue(keys[key_id])
-                return
-        raise ValueError("No verification key found for given key ids")
+        merged = {}
+        for result in results:
+            merged.update(result)
+
+        defer.returnValue({
+            server_name: keys
+            for server_name, keys in merged.items()
+            if keys
+        })
 
     @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                           perspective_name,
                                           perspective_keys):
         limiter = yield get_retry_limiter(
@@ -204,6 +340,7 @@ class Keyring(object):
                                 u"minimum_valid_until_ts": 0
                             } for key_id in key_ids
                         }
+                        for server_name, key_ids in server_names_and_key_ids
                     }
                 },
             )
@@ -243,23 +380,29 @@ class Keyring(object):
                     " server %r" % (perspective_name,)
                 )
 
-            response_keys = yield self.process_v2_response(
-                server_name, perspective_name, response
+            processed_response = yield self.process_v2_response(
+                perspective_name, response
             )
 
-            keys.update(response_keys)
+            for server_name, response_keys in processed_response.items():
+                keys.setdefault(server_name, {}).update(response_keys)
 
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=perspective_name,
-            verify_keys=keys,
-        )
+        yield defer.gatherResults(
+            [
+                self.store_keys(
+                    server_name=server_name,
+                    from_server=perspective_name,
+                    verify_keys=response_keys,
+                )
+                for server_name, response_keys in keys.items()
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
         keys = {}
 
         for requested_key_id in key_ids:
@@ -295,25 +438,30 @@ class Keyring(object):
                 raise ValueError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
-                server_name=server_name,
                 from_server=server_name,
-                requested_id=requested_key_id,
+                requested_ids=[requested_key_id],
                 response_json=response,
             )
 
             keys.update(response_keys)
 
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=keys,
-        )
+        yield defer.gatherResults(
+            [
+                self.store_keys(
+                    server_name=key_server_name,
+                    from_server=server_name,
+                    verify_keys=verify_keys,
+                )
+                for key_server_name, verify_keys in keys.items()
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
     @defer.inlineCallbacks
-    def process_v2_response(self, server_name, from_server, response_json,
-                            requested_id=None):
+    def process_v2_response(self, from_server, response_json,
+                            requested_ids=[]):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -335,6 +483,8 @@ class Keyring(object):
                 verify_key.time_added = time_now_ms
                 old_verify_keys[key_id] = verify_key
 
+        results = {}
+        server_name = response_json["server_name"]
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
                 raise ValueError(
@@ -357,28 +507,31 @@ class Keyring(object):
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
         ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
 
-        updated_key_ids = set()
-        if requested_id is not None:
-            updated_key_ids.add(requested_id)
+        updated_key_ids = set(requested_ids)
         updated_key_ids.update(verify_keys)
         updated_key_ids.update(old_verify_keys)
 
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        for key_id in updated_key_ids:
-            yield self.store.store_server_keys_json(
-                server_name=server_name,
-                key_id=key_id,
-                from_server=server_name,
-                ts_now_ms=time_now_ms,
-                ts_expires_ms=ts_valid_until_ms,
-                key_json_bytes=signed_key_json_bytes,
-            )
+        yield defer.gatherResults(
+            [
+                self.store.store_server_keys_json(
+                    server_name=server_name,
+                    key_id=key_id,
+                    from_server=server_name,
+                    ts_now_ms=time_now_ms,
+                    ts_expires_ms=ts_valid_until_ms,
+                    key_json_bytes=signed_key_json_bytes,
+                )
+                for key_id in updated_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
-        defer.returnValue(response_keys)
+        results[server_name] = response_keys
 
-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue(results)
 
     @defer.inlineCallbacks
     def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -462,8 +615,13 @@ class Keyring(object):
         Returns:
             A deferred that completes when the keys are stored.
         """
-        for key_id, key in verify_keys.items():
-            # TODO(markjh): Store whether the keys have expired.
-            yield self.store.store_server_verify_key(
-                server_name, server_name, key.time_added, key
-            )
+        # TODO(markjh): Store whether the keys have expired.
+        yield defer.gatherResults(
+            [
+                self.store.store_server_verify_key(
+                    server_name, server_name, key.time_added, key
+                )
+                for key_id, key in verify_keys.items()
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..bdfa247604 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
 
 class FederationBase(object):
     @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
+                                       include_none=False):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
         the database and if not then request if from the originating server of
@@ -50,84 +51,108 @@ class FederationBase(object):
         Returns:
             Deferred : A list of PDUs that have valid signatures and hashes.
         """
+        deferreds = self._check_sigs_and_hashes(pdus)
 
-        signed_pdus = []
+        def callback(pdu):
+            return pdu
 
-        @defer.inlineCallbacks
-        def do(pdu):
-            try:
-                new_pdu = yield self._check_sigs_and_hash(pdu)
-                signed_pdus.append(new_pdu)
-            except SynapseError:
-                # FIXME: We should handle signature failures more gracefully.
+        def errback(failure, pdu):
+            failure.trap(SynapseError)
+            return None
 
+        def try_local_db(res, pdu):
+            if not res:
                 # Check local db.
-                new_pdu = yield self.store.get_event(
+                return self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-                if new_pdu:
-                    signed_pdus.append(new_pdu)
-                    return
-
-                # Check pdu.origin
-                if pdu.origin != origin:
-                    try:
-                        new_pdu = yield self.get_pdu(
-                            destinations=[pdu.origin],
-                            event_id=pdu.event_id,
-                            outlier=outlier,
-                            timeout=10000,
-                        )
-
-                        if new_pdu:
-                            signed_pdus.append(new_pdu)
-                            return
-                    except:
-                        pass
-
+            return res
+
+        def try_remote(res, pdu):
+            if not res and pdu.origin != origin:
+                return self.get_pdu(
+                    destinations=[pdu.origin],
+                    event_id=pdu.event_id,
+                    outlier=outlier,
+                    timeout=10000,
+                ).addErrback(lambda e: None)
+            return res
+
+        def warn(res, pdu):
+            if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
+            return res
+
+        for pdu, deferred in zip(pdus, deferreds):
+            deferred.addCallbacks(
+                callback, errback, errbackArgs=[pdu]
+            ).addCallback(
+                try_local_db, pdu
+            ).addCallback(
+                try_remote, pdu
+            ).addCallback(
+                warn, pdu
+            )
 
-        yield defer.gatherResults(
-            [do(pdu) for pdu in pdus],
+        valid_pdus = yield defer.gatherResults(
+            deferreds,
             consumeErrors=True
         ).addErrback(unwrapFirstError)
 
-        defer.returnValue(signed_pdus)
+        if include_none:
+            defer.returnValue(valid_pdus)
+        else:
+            defer.returnValue([p for p in valid_pdus if p])
 
-    @defer.inlineCallbacks
     def _check_sigs_and_hash(self, pdu):
-        """Throws a SynapseError if the PDU does not have the correct
+        return self._check_sigs_and_hashes([pdu])[0]
+
+    def _check_sigs_and_hashes(self, pdus):
+        """Throws a SynapseError if a PDU does not have the correct
         signatures.
 
         Returns:
             FrozenEvent: Either the given event or it redacted if it failed the
             content hash check.
         """
-        # Check signatures are correct.
-        redacted_event = prune_event(pdu)
-        redacted_pdu_json = redacted_event.get_pdu_json()
 
-        try:
-            yield self.keyring.verify_json_for_server(
-                pdu.origin, redacted_pdu_json
-            )
-        except SynapseError:
+        redacted_pdus = [
+            prune_event(pdu)
+            for pdu in pdus
+        ]
+
+        deferreds = self.keyring.verify_json_objects_for_server([
+            (p.origin, p.get_pdu_json())
+            for p in redacted_pdus
+        ])
+
+        def callback(_, pdu, redacted):
+            if not check_event_content_hash(pdu):
+                logger.warn(
+                    "Event content has been tampered, redacting %s: %s",
+                    pdu.event_id, pdu.get_pdu_json()
+                )
+                return redacted
+            return pdu
+
+        def errback(failure, pdu):
+            failure.trap(SynapseError)
             logger.warn(
                 "Signature check failed for %s",
                 pdu.event_id,
             )
-            raise
+            return failure
 
-        if not check_event_content_hash(pdu):
-            logger.warn(
-                "Event content has been tampered, redacting.",
-                pdu.event_id,
+        for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+            deferred.addCallbacks(
+                callback, errback,
+                callbackArgs=[pdu, redacted],
+                errbackArgs=[pdu],
             )
-            defer.returnValue(redacted_event)
 
-        defer.returnValue(pdu)
+        return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..7736d14fb5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -30,6 +30,7 @@ import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
+import copy
 import itertools
 import logging
 import random
@@ -167,7 +168,7 @@ class FederationClient(FederationBase):
 
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = yield defer.gatherResults(
-            [self._check_sigs_and_hash(pdu) for pdu in pdus],
+            self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
@@ -230,7 +231,7 @@ class FederationClient(FederationBase):
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
-                        pdu = yield self._check_sigs_and_hash(pdu)
+                        pdu = yield self._check_sigs_and_hashes([pdu])[0]
 
                         break
 
@@ -327,6 +328,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def make_join(self, destinations, room_id, user_id):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 ret = yield self.transport_layer.make_join(
                     destination, room_id, user_id
@@ -353,6 +357,9 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
         for destination in destinations:
+            if destination == self.server_name:
+                continue
+
             try:
                 time_now = self._clock.time_msec()
                 _, content = yield self.transport_layer.send_join(
@@ -374,17 +381,39 @@ class FederationClient(FederationBase):
                     for p in content.get("auth_chain", [])
                 ]
 
-                signed_state, signed_auth = yield defer.gatherResults(
-                    [
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, state, outlier=True
-                        ),
-                        self._check_sigs_and_hash_and_fetch(
-                            destination, auth_chain, outlier=True
-                        )
-                    ],
-                    consumeErrors=True
-                ).addErrback(unwrapFirstError)
+                pdus = {
+                    p.event_id: p
+                    for p in itertools.chain(state, auth_chain)
+                }
+
+                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+                    destination, pdus.values(),
+                    outlier=True,
+                )
+
+                valid_pdus_map = {
+                    p.event_id: p
+                    for p in valid_pdus
+                }
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                signed_state = [
+                    copy.copy(valid_pdus_map[p.event_id])
+                    for p in state
+                    if p.event_id in valid_pdus_map
+                ]
+
+                signed_auth = [
+                    valid_pdus_map[p.event_id]
+                    for p in auth_chain
+                    if p.event_id in valid_pdus_map
+                ]
+
+                # NB: We *need* to copy to ensure that we don't have multiple
+                # references being passed on, as that causes... issues.
+                for s in signed_state:
+                    s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
                 auth_chain.sort(key=lambda e: e.depth)
 
@@ -396,7 +425,7 @@ class FederationClient(FederationBase):
             except CodeMessageException:
                 raise
             except Exception as e:
-                logger.warn(
+                logger.exception(
                     "Failed to send_join via %s: %s",
                     destination, e.message
                 )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..940a5f7e08 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
@@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
+    @cached()
+    @defer.inlineCallbacks
+    def get_all_server_verify_keys(self, server_name):
+        rows = yield self._simple_select_list(
+            table="server_signature_keys",
+            keyvalues={
+                "server_name": server_name,
+            },
+            retcols=["key_id", "verify_key"],
+            desc="get_all_server_verify_keys",
+        )
+
+        defer.returnValue({
+            row["key_id"]: decode_verify_key_bytes(
+                row["key_id"], str(row["verify_key"])
+            )
+            for row in rows
+        })
+
     @defer.inlineCallbacks
     def get_server_verify_keys(self, server_name, key_ids):
         """Retrieve the NACL verification key for a given server for the given
@@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
         Returns:
             (list of VerifyKey): The verification keys.
         """
-        sql = (
-            "SELECT key_id, verify_key FROM server_signature_keys"
-            " WHERE server_name = ?"
-            " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
-        )
-
-        rows = yield self._execute_and_decode(
-            "get_server_verify_keys", sql, server_name, *key_ids
-        )
-
-        keys = []
-        for row in rows:
-            key_id = row["key_id"]
-            key_bytes = row["verify_key"]
-            key = decode_verify_key_bytes(key_id, str(key_bytes))
-            keys.append(key)
-        defer.returnValue(keys)
+        keys = yield self.get_all_server_verify_keys(server_name)
+        defer.returnValue({
+            k: keys[k]
+            for k in key_ids
+            if k in keys and keys[k]
+        })
 
+    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
@@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
             ts_now_ms (int): The time now in milliseconds
             verification_key (VerifyKey): The NACL verify key.
         """
-        return self._simple_upsert(
+        yield self._simple_upsert(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
@@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
+        self.get_all_server_verify_keys.invalidate(server_name)
+
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
         """Stores the JSON bytes for a set of keys from a server
@@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
                 "ts_valid_until_ms": ts_expires_ms,
                 "key_json": buffer(key_json_bytes),
             },
+            desc="store_server_keys_json",
         )
 
     def get_server_keys_json(self, server_keys):