summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py899
1 files changed, 517 insertions, 382 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d8ba870cca..0fd15287e7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
+from collections import defaultdict
 
+import six
 from six import raise_from
 from six.moves import urllib
 
-import nacl.signing
+import attr
 from signedjson.key import (
     decode_verify_key_bytes,
     encode_verify_key_base64,
@@ -43,7 +44,9 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.logcontext import (
     LoggingContext,
     PreserveLoggingContext,
@@ -56,22 +59,36 @@ from synapse.util.retryutils import NotRetryingDestination
 logger = logging.getLogger(__name__)
 
 
-VerifyKeyRequest = namedtuple("VerifyRequest", (
-    "server_name", "key_ids", "json_object", "deferred"
-))
-"""
-A request for a verify key to verify a JSON object.
+@attr.s(slots=True, cmp=False)
+class VerifyKeyRequest(object):
+    """
+    A request for a verify key to verify a JSON object.
+
+    Attributes:
+        server_name(str): The name of the server to verify against.
+
+        key_ids(set[str]): The set of key_ids to that could be used to verify the
+            JSON object
 
-Attributes:
-    server_name(str): The name of the server to verify against.
-    key_ids(set(str)): The set of key_ids to that could be used to verify the
-        JSON object
-    json_object(dict): The JSON object to verify.
-    deferred(Deferred[str, str, nacl.signing.VerifyKey]):
-        A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched. The deferreds' callbacks are run with no
-        logcontext.
-"""
+        json_object(dict): The JSON object to verify.
+
+        minimum_valid_until_ts (int): time at which we require the signing key to
+            be valid. (0 implies we don't care)
+
+        deferred(Deferred[str, str, nacl.signing.VerifyKey]):
+            A deferred (server_name, key_id, verify_key) tuple that resolves when
+            a verify key has been fetched. The deferreds' callbacks are run with no
+            logcontext.
+
+            If we are unable to find a key which satisfies the request, the deferred
+            errbacks with an M_UNAUTHORIZED SynapseError.
+    """
+
+    server_name = attr.ib()
+    key_ids = attr.ib()
+    json_object = attr.ib()
+    minimum_valid_until_ts = attr.ib()
+    deferred = attr.ib(default=attr.Factory(defer.Deferred))
 
 
 class KeyLookupError(ValueError):
@@ -79,13 +96,16 @@ class KeyLookupError(ValueError):
 
 
 class Keyring(object):
-    def __init__(self, hs):
-        self.store = hs.get_datastore()
+    def __init__(self, hs, key_fetchers=None):
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
-        self.config = hs.get_config()
-        self.perspective_servers = self.config.perspectives
-        self.hs = hs
+
+        if key_fetchers is None:
+            key_fetchers = (
+                StoreKeyFetcher(hs),
+                PerspectivesKeyFetcher(hs),
+                ServerKeyFetcher(hs),
+            )
+        self._key_fetchers = key_fetchers
 
         # map from server name to Deferred. Has an entry for each server with
         # an ongoing key download; the Deferred completes once the download
@@ -94,11 +114,25 @@ class Keyring(object):
         # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
-    def verify_json_for_server(self, server_name, json_object):
+    def verify_json_for_server(self, server_name, json_object, validity_time):
+        """Verify that a JSON object has been signed by a given server
+
+        Args:
+            server_name (str): name of the server which must have signed this object
+
+            json_object (dict): object to be checked
+
+            validity_time (int): timestamp at which we require the signing key to
+                be valid. (0 implies we don't care)
+
+        Returns:
+            Deferred[None]: completes if the the object was correctly signed, otherwise
+                errbacks with an error
+        """
+        req = server_name, json_object, validity_time
+
         return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server(
-                [(server_name, json_object)]
-            )[0]
+            self.verify_json_objects_for_server((req,))[0]
         )
 
     def verify_json_objects_for_server(self, server_and_json):
@@ -106,10 +140,12 @@ class Keyring(object):
         necessary.
 
         Args:
-            server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int]):
+                Iterable of triplets of (server_name, json_object, validity_time)
+                validity_time is a timestamp at which the signing key must be valid.
 
         Returns:
-            List<Deferred>: for each input pair, a deferred indicating success
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
@@ -118,30 +154,32 @@ class Keyring(object):
         verify_requests = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(server_name, json_object):
+        def process(server_name, json_object, validity_time):
             """Process an entry in the request list
 
-            Given a (server_name, json_object) pair from the request list,
-            adds a key request to verify_requests, and returns a deferred which will
-            complete or fail (in the sentinel context) when verification completes.
+            Given a (server_name, json_object, validity_time) triplet from the request
+            list, adds a key request to verify_requests, and returns a deferred which
+            will complete or fail (in the sentinel context) when verification completes.
             """
             key_ids = signature_ids(json_object, server_name)
 
             if not key_ids:
                 return defer.fail(
                     SynapseError(
-                        400,
-                        "Not signed by %s" % (server_name,),
-                        Codes.UNAUTHORIZED,
+                        400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
                     )
                 )
 
-            logger.debug("Verifying for %s with key_ids %s",
-                         server_name, key_ids)
+            logger.debug(
+                "Verifying for %s with key_ids %s, min_validity %i",
+                server_name,
+                key_ids,
+                validity_time,
+            )
 
             # add the key request to the queue, but don't start it off yet.
             verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, defer.Deferred(),
+                server_name, key_ids, json_object, validity_time
             )
             verify_requests.append(verify_request)
 
@@ -153,8 +191,8 @@ class Keyring(object):
             return handle(verify_request)
 
         results = [
-            process(server_name, json_object)
-            for server_name, json_object in server_and_json
+            process(server_name, json_object, validity_time)
+            for server_name, json_object, validity_time in server_and_json
         ]
 
         if verify_requests:
@@ -179,16 +217,12 @@ class Keyring(object):
             # any other lookups until we have finished.
             # The deferreds are called with no logcontext.
             server_to_deferred = {
-                rq.server_name: defer.Deferred()
-                for rq in verify_requests
+                rq.server_name: defer.Deferred() for rq in verify_requests
             }
 
             # We want to wait for any previous lookups to complete before
             # proceeding.
-            yield self.wait_for_previous_lookups(
-                [rq.server_name for rq in verify_requests],
-                server_to_deferred,
-            )
+            yield self.wait_for_previous_lookups(server_to_deferred)
 
             # Actually start fetching keys.
             self._get_server_verify_keys(verify_requests)
@@ -216,19 +250,16 @@ class Keyring(object):
                 return res
 
             for verify_request in verify_requests:
-                verify_request.deferred.addBoth(
-                    remove_deferreds, verify_request,
-                )
+                verify_request.deferred.addBoth(remove_deferreds, verify_request)
         except Exception:
             logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
-    def wait_for_previous_lookups(self, server_names, server_to_deferred):
+    def wait_for_previous_lookups(self, server_to_deferred):
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_names (list): list of server_names we want to lookup
-            server_to_deferred (dict): server_name to deferred which gets
+            server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
                 resolved once we've finished looking up keys for that server.
                 The Deferreds should be regular twisted ones which call their
                 callbacks with no logcontext.
@@ -241,14 +272,15 @@ class Keyring(object):
         while True:
             wait_on = [
                 (server_name, self.key_downloads[server_name])
-                for server_name in server_names
+                for server_name in server_to_deferred.keys()
                 if server_name in self.key_downloads
             ]
             if not wait_on:
                 break
             logger.info(
                 "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on], loop_count,
+                [w[0] for w in wait_on],
+                loop_count,
             )
             with PreserveLoggingContext():
                 yield defer.DeferredList((w[1] for w in wait_on))
@@ -279,129 +311,290 @@ class Keyring(object):
             verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
-        # These are functions that produce keys given a list of key ids
-        key_fetch_fns = (
-            self.get_keys_from_store,  # First try the local store
-            self.get_keys_from_perspectives,  # Then try via perspectives
-            self.get_keys_from_server,  # Then try directly
+        remaining_requests = set(
+            (rq for rq in verify_requests if not rq.deferred.called)
         )
 
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
-                # dict[str, set(str)]: keys to fetch for each server
-                missing_keys = {}
-                for verify_request in verify_requests:
-                    missing_keys.setdefault(verify_request.server_name, set()).update(
-                        verify_request.key_ids
-                    )
-
-                for fn in key_fetch_fns:
-                    results = yield fn(missing_keys.items())
-
-                    # We now need to figure out which verify requests we have keys
-                    # for and which we don't
-                    missing_keys = {}
-                    requests_missing_keys = []
-                    for verify_request in verify_requests:
-                        if verify_request.deferred.called:
-                            # We've already called this deferred, which probably
-                            # means that we've already found a key for it.
-                            continue
-
-                        server_name = verify_request.server_name
-
-                        # see if any of the keys we got this time are sufficient to
-                        # complete this VerifyKeyRequest.
-                        result_keys = results.get(server_name, {})
-                        for key_id in verify_request.key_ids:
-                            key = result_keys.get(key_id)
-                            if key:
-                                with PreserveLoggingContext():
-                                    verify_request.deferred.callback(
-                                        (server_name, key_id, key)
-                                    )
-                                break
-                        else:
-                            # The else block is only reached if the loop above
-                            # doesn't break.
-                            missing_keys.setdefault(server_name, set()).update(
-                                verify_request.key_ids
-                            )
-                            requests_missing_keys.append(verify_request)
-
-                    if not missing_keys:
-                        break
+                for f in self._key_fetchers:
+                    if not remaining_requests:
+                        return
+                    yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
 
+                # look for any requests which weren't satisfied
                 with PreserveLoggingContext():
-                    for verify_request in requests_missing_keys:
-                        verify_request.deferred.errback(SynapseError(
-                            401,
-                            "No key for %s with id %s" % (
-                                verify_request.server_name, verify_request.key_ids,
-                            ),
-                            Codes.UNAUTHORIZED,
-                        ))
+                    for verify_request in remaining_requests:
+                        verify_request.deferred.errback(
+                            SynapseError(
+                                401,
+                                "No key for %s with ids in %s (min_validity %i)"
+                                % (
+                                    verify_request.server_name,
+                                    verify_request.key_ids,
+                                    verify_request.minimum_valid_until_ts,
+                                ),
+                                Codes.UNAUTHORIZED,
+                            )
+                        )
 
         def on_err(err):
+            # we don't really expect to get here, because any errors should already
+            # have been caught and logged. But if we do, let's log the error and make
+            # sure that all of the deferreds are resolved.
+            logger.error("Unexpected error in _get_server_verify_keys: %s", err)
             with PreserveLoggingContext():
-                for verify_request in verify_requests:
+                for verify_request in remaining_requests:
                     if not verify_request.deferred.called:
                         verify_request.deferred.errback(err)
 
         run_in_background(do_iterations).addErrback(on_err)
 
     @defer.inlineCallbacks
-    def get_keys_from_store(self, server_name_and_key_ids):
+    def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+        """Use a key fetcher to attempt to satisfy some key requests
+
+        Args:
+            fetcher (KeyFetcher): fetcher to use to fetch the keys
+            remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
+                Any successfully-completed requests will be removed from the list.
+        """
+        # dict[str, dict[str, int]]: keys to fetch.
+        # server_name -> key_id -> min_valid_ts
+        missing_keys = defaultdict(dict)
+
+        for verify_request in remaining_requests:
+            # any completed requests should already have been removed
+            assert not verify_request.deferred.called
+            keys_for_server = missing_keys[verify_request.server_name]
+
+            for key_id in verify_request.key_ids:
+                # If we have several requests for the same key, then we only need to
+                # request that key once, but we should do so with the greatest
+                # min_valid_until_ts of the requests, so that we can satisfy all of
+                # the requests.
+                keys_for_server[key_id] = max(
+                    keys_for_server.get(key_id, -1),
+                    verify_request.minimum_valid_until_ts
+                )
+
+        results = yield fetcher.get_keys(missing_keys)
+
+        completed = list()
+        for verify_request in remaining_requests:
+            server_name = verify_request.server_name
+
+            # see if any of the keys we got this time are sufficient to
+            # complete this VerifyKeyRequest.
+            result_keys = results.get(server_name, {})
+            for key_id in verify_request.key_ids:
+                fetch_key_result = result_keys.get(key_id)
+                if not fetch_key_result:
+                    # we didn't get a result for this key
+                    continue
+
+                if (
+                    fetch_key_result.valid_until_ts
+                    < verify_request.minimum_valid_until_ts
+                ):
+                    # key was not valid at this point
+                    continue
+
+                with PreserveLoggingContext():
+                    verify_request.deferred.callback(
+                        (server_name, key_id, fetch_key_result.verify_key)
+                    )
+                completed.append(verify_request)
+                break
+
+        remaining_requests.difference_update(completed)
+
+
+class KeyFetcher(object):
+    def get_keys(self, keys_to_fetch):
         """
         Args:
-            server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
-            Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
-                server_name -> key_id -> VerifyKey
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
         """
+        raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+    """KeyFetcher impl which fetches keys from our data store"""
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def get_keys(self, keys_to_fetch):
+        """see KeyFetcher.get_keys"""
+
         keys_to_fetch = (
             (server_name, key_id)
-            for server_name, key_ids in server_name_and_key_ids
-            for key_id in key_ids
+            for server_name, keys_for_server in keys_to_fetch.items()
+            for key_id in keys_for_server.keys()
         )
+
         res = yield self.store.get_server_verify_keys(keys_to_fetch)
         keys = {}
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
         defer.returnValue(keys)
 
+
+class BaseV2KeyFetcher(object):
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.config = hs.get_config()
+
+    @defer.inlineCallbacks
+    def process_v2_response(
+        self, from_server, response_json, time_added_ms
+    ):
+        """Parse a 'Server Keys' structure from the result of a /key request
+
+        This is used to parse either the entirety of the response from
+        GET /_matrix/key/v2/server, or a single entry from the list returned by
+        POST /_matrix/key/v2/query.
+
+        Checks that each signature in the response that claims to come from the origin
+        server is valid, and that there is at least one such signature.
+
+        Stores the json in server_keys_json so that it can be used for future responses
+        to /_matrix/key/v2/query.
+
+        Args:
+            from_server (str): the name of the server producing this result: either
+                the origin server for a /_matrix/key/v2/server request, or the notary
+                for a /_matrix/key/v2/query.
+
+            response_json (dict): the json-decoded Server Keys response object
+
+            time_added_ms (int): the timestamp to record in server_keys_json
+
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+        """
+        ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+        # start by extracting the keys from the response, since they may be required
+        # to validate the signature on the response.
+        verify_keys = {}
+        for key_id, key_data in response_json["verify_keys"].items():
+            if is_signing_algorithm_supported(key_id):
+                key_base64 = key_data["key"]
+                key_bytes = decode_base64(key_base64)
+                verify_key = decode_verify_key_bytes(key_id, key_bytes)
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+                )
+
+        server_name = response_json["server_name"]
+        verified = False
+        for key_id in response_json["signatures"].get(server_name, {}):
+            # each of the keys used for the signature must be present in the response
+            # json.
+            key = verify_keys.get(key_id)
+            if not key:
+                raise KeyLookupError(
+                    "Key response is signed by key id %s:%s but that key is not "
+                    "present in the response" % (server_name, key_id)
+                )
+
+            verify_signed_json(response_json, server_name, key.verify_key)
+            verified = True
+
+        if not verified:
+            raise KeyLookupError(
+                "Key response for %s is not signed by the origin server"
+                % (server_name,)
+            )
+
+        for key_id, key_data in response_json["old_verify_keys"].items():
+            if is_signing_algorithm_supported(key_id):
+                key_base64 = key_data["key"]
+                key_bytes = decode_base64(key_base64)
+                verify_key = decode_verify_key_bytes(key_id, key_bytes)
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+                )
+
+        # re-sign the json with our own key, so that it is ready if we are asked to
+        # give it out as a notary server
+        signed_key_json = sign_json(
+            response_json, self.config.server_name, self.config.signing_key[0]
+        )
+
+        signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+        yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.store.store_server_keys_json,
+                        server_name=server_name,
+                        key_id=key_id,
+                        from_server=from_server,
+                        ts_now_ms=time_added_ms,
+                        ts_expires_ms=ts_valid_until_ms,
+                        key_json_bytes=signed_key_json_bytes,
+                    )
+                    for key_id in verify_keys
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
+
+        defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+    """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+    def __init__(self, hs):
+        super(PerspectivesKeyFetcher, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.client = hs.get_http_client()
+        self.perspective_servers = self.config.perspectives
+
     @defer.inlineCallbacks
-    def get_keys_from_perspectives(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
+        """see KeyFetcher.get_keys"""
+
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name_and_key_ids, perspective_name, perspective_keys
+                    keys_to_fetch, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except KeyLookupError as e:
-                logger.warning(
-                    "Key lookup failed from %r: %s", perspective_name, e,
-                )
+                logger.warning("Key lookup failed from %r: %s", perspective_name, e)
             except Exception as e:
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
-                    type(e).__name__, str(e),
+                    type(e).__name__,
+                    str(e),
                 )
 
             defer.returnValue({})
 
-        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(get_key, p_name, p_keys)
-                for p_name, p_keys in self.perspective_servers.items()
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        results = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(get_key, p_name, p_keys)
+                    for p_name, p_keys in self.perspective_servers.items()
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
         union_of_keys = {}
         for result in results:
@@ -411,36 +604,33 @@ class Keyring(object):
         defer.returnValue(union_of_keys)
 
     @defer.inlineCallbacks
-    def get_keys_from_server(self, server_name_and_key_ids):
-        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.get_server_verify_key_v2_direct,
-                    server_name,
-                    key_ids,
-                )
-                for server_name, key_ids in server_name_and_key_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+    def get_server_verify_key_v2_indirect(
+        self, keys_to_fetch, perspective_name, perspective_keys
+    ):
+        """
+        Args:
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
 
-        merged = {}
-        for result in results:
-            merged.update(result)
+            perspective_name (str): name of the notary server to query for the keys
 
-        defer.returnValue({
-            server_name: keys
-            for server_name, keys in merged.items()
-            if keys
-        })
+            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+                notary server
+
+        Returns:
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+                from server_name -> key_id -> FetchKeyResult
+
+        Raises:
+            KeyLookupError if there was an error processing the entire response from
+                the server
+        """
+        logger.info(
+            "Requesting keys %s from notary server %s",
+            keys_to_fetch.items(),
+            perspective_name,
+        )
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
-                                          perspective_name,
-                                          perspective_keys):
-        # TODO(mark): Set the minimum_valid_until_ts to that needed by
-        # the events being validated or the current time if validating
-        # an incoming request.
         try:
             query_response = yield self.client.post_json(
                 destination=perspective_name,
@@ -448,249 +638,213 @@ class Keyring(object):
                 data={
                     u"server_keys": {
                         server_name: {
-                            key_id: {
-                                u"minimum_valid_until_ts": 0
-                            } for key_id in key_ids
+                            key_id: {u"minimum_valid_until_ts": min_valid_ts}
+                            for key_id, min_valid_ts in server_keys.items()
                         }
-                        for server_name, key_ids in server_names_and_key_ids
+                        for server_name, server_keys in keys_to_fetch.items()
                     }
                 },
-                long_retries=True,
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
-            raise_from(
-                KeyLookupError("Failed to connect to remote server"), e,
-            )
+            raise_from(KeyLookupError("Failed to connect to remote server"), e)
         except HttpResponseException as e:
-            raise_from(
-                KeyLookupError("Remote server returned an error"), e,
-            )
+            raise_from(KeyLookupError("Remote server returned an error"), e)
 
         keys = {}
+        added_keys = []
 
-        responses = query_response["server_keys"]
+        time_now_ms = self.clock.time_msec()
 
-        for response in responses:
-            if (u"signatures" not in response
-                    or perspective_name not in response[u"signatures"]):
+        for response in query_response["server_keys"]:
+            # do this first, so that we can give useful errors thereafter
+            server_name = response.get("server_name")
+            if not isinstance(server_name, six.string_types):
                 raise KeyLookupError(
-                    "Key response not signed by perspective server"
-                    " %r" % (perspective_name,)
+                    "Malformed response from key notary server %s: invalid server_name"
+                    % (perspective_name,)
                 )
 
-            verified = False
-            for key_id in response[u"signatures"][perspective_name]:
-                if key_id in perspective_keys:
-                    verify_signed_json(
-                        response,
-                        perspective_name,
-                        perspective_keys[key_id]
-                    )
-                    verified = True
-
-            if not verified:
-                logging.info(
-                    "Response from perspective server %r not signed with a"
-                    " known key, signed with: %r, known keys: %r",
+            try:
+                processed_response = yield self._process_perspectives_response(
                     perspective_name,
-                    list(response[u"signatures"][perspective_name]),
-                    list(perspective_keys)
+                    perspective_keys,
+                    response,
+                    time_added_ms=time_now_ms,
                 )
-                raise KeyLookupError(
-                    "Response not signed with a known key for perspective"
-                    " server %r" % (perspective_name,)
+            except KeyLookupError as e:
+                logger.warning(
+                    "Error processing response from key notary server %s for origin "
+                    "server %s: %s",
+                    perspective_name,
+                    server_name,
+                    e,
                 )
+                # we continue to process the rest of the response
+                continue
 
-            processed_response = yield self.process_v2_response(
-                perspective_name, response
+            added_keys.extend(
+                (server_name, key_id, key) for key_id, key in processed_response.items()
             )
-            server_name = response["server_name"]
-
             keys.setdefault(server_name, {}).update(processed_response)
 
-        yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store_keys,
-                    server_name=server_name,
-                    from_server=perspective_name,
-                    verify_keys=response_keys,
-                )
-                for server_name, response_keys in keys.items()
-            ],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError))
+        yield self.store.store_server_verify_keys(
+            perspective_name, time_now_ms, added_keys
+        )
 
         defer.returnValue(keys)
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v2_direct(self, server_name, key_ids):
-        keys = {}  # type: dict[str, nacl.signing.VerifyKey]
+    def _process_perspectives_response(
+        self, perspective_name, perspective_keys, response, time_added_ms
+    ):
+        """Parse a 'Server Keys' structure from the result of a /key/query request
 
-        for requested_key_id in key_ids:
-            if requested_key_id in keys:
-                continue
+        Checks that the entry is correctly signed by the perspectives server, and then
+        passes over to process_v2_response
 
-            try:
-                response = yield self.client.get_json(
-                    destination=server_name,
-                    path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
-                    ignore_backoff=True,
-                )
-            except (NotRetryingDestination, RequestSendFailed) as e:
-                raise_from(
-                    KeyLookupError("Failed to connect to remote server"), e,
-                )
-            except HttpResponseException as e:
-                raise_from(
-                    KeyLookupError("Remote server returned an error"), e,
-                )
+        Args:
+            perspective_name (str): the name of the notary server that produced this
+                result
 
-            if (u"signatures" not in response
-                    or server_name not in response[u"signatures"]):
-                raise KeyLookupError("Key response not signed by remote server")
+            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+                notary server
 
-            if response["server_name"] != server_name:
-                raise KeyLookupError("Expected a response for server %r not %r" % (
-                    server_name, response["server_name"]
-                ))
+            response (dict): the json-decoded Server Keys response object
 
-            response_keys = yield self.process_v2_response(
-                from_server=server_name,
-                requested_ids=[requested_key_id],
-                response_json=response,
-            )
+            time_added_ms (int): the timestamp to record in server_keys_json
 
-            keys.update(response_keys)
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+        """
+        if (
+            u"signatures" not in response
+            or perspective_name not in response[u"signatures"]
+        ):
+            raise KeyLookupError("Response not signed by the notary server")
+
+        verified = False
+        for key_id in response[u"signatures"][perspective_name]:
+            if key_id in perspective_keys:
+                verify_signed_json(response, perspective_name, perspective_keys[key_id])
+                verified = True
+
+        if not verified:
+            raise KeyLookupError(
+                "Response not signed with a known key: signed with: %r, known keys: %r"
+                % (
+                    list(response[u"signatures"][perspective_name].keys()),
+                    list(perspective_keys.keys()),
+                )
+            )
 
-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=keys,
+        return self.process_v2_response(
+            perspective_name, response, time_added_ms=time_added_ms
         )
-        defer.returnValue({server_name: keys})
 
-    @defer.inlineCallbacks
-    def process_v2_response(
-        self, from_server, response_json, requested_ids=[],
-    ):
-        """Parse a 'Server Keys' structure from the result of a /key request
 
-        This is used to parse either the entirety of the response from
-        GET /_matrix/key/v2/server, or a single entry from the list returned by
-        POST /_matrix/key/v2/query.
-
-        Checks that each signature in the response that claims to come from the origin
-        server is valid. (Does not check that there actually is such a signature, for
-        some reason.)
+class ServerKeyFetcher(BaseV2KeyFetcher):
+    """KeyFetcher impl which fetches keys from the origin servers"""
 
-        Stores the json in server_keys_json so that it can be used for future responses
-        to /_matrix/key/v2/query.
+    def __init__(self, hs):
+        super(ServerKeyFetcher, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.client = hs.get_http_client()
 
+    def get_keys(self, keys_to_fetch):
+        """
         Args:
-            from_server (str): the name of the server producing this result: either
-                the origin server for a /_matrix/key/v2/server request, or the notary
-                for a /_matrix/key/v2/query.
-
-            response_json (dict): the json-decoded Server Keys response object
-
-            requested_ids (iterable[str]): a list of the key IDs that were requested.
-                We will store the json for these key ids as well as any that are
-                actually in the response
+            keys_to_fetch (dict[str, iterable[str]]):
+                the keys to be fetched. server_name -> key_ids
 
         Returns:
-            Deferred[dict[str, nacl.signing.VerifyKey]]:
-                map from key_id to key object
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
         """
-        time_now_ms = self.clock.time_msec()
-        response_keys = {}
-        verify_keys = {}
-        for key_id, key_data in response_json["verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_base64 = key_data["key"]
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_key.time_added = time_now_ms
-                verify_keys[key_id] = verify_key
 
-        old_verify_keys = {}
-        for key_id, key_data in response_json["old_verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_base64 = key_data["key"]
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_key.expired = key_data["expired_ts"]
-                verify_key.time_added = time_now_ms
-                old_verify_keys[key_id] = verify_key
+        results = {}
 
-        server_name = response_json["server_name"]
-        for key_id in response_json["signatures"].get(server_name, {}):
-            if key_id not in response_json["verify_keys"]:
-                raise KeyLookupError(
-                    "Key response must include verification keys for all"
-                    " signatures"
-                )
-            if key_id in verify_keys:
-                verify_signed_json(
-                    response_json,
-                    server_name,
-                    verify_keys[key_id]
+        @defer.inlineCallbacks
+        def get_key(key_to_fetch_item):
+            server_name, key_ids = key_to_fetch_item
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+                results[server_name] = keys
+            except KeyLookupError as e:
+                logger.warning(
+                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
                 )
+            except Exception:
+                logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        signed_key_json = sign_json(
-            response_json,
-            self.config.server_name,
-            self.config.signing_key[0],
+        return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+            lambda _: results
         )
 
-        signed_key_json_bytes = encode_canonical_json(signed_key_json)
-        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
-
-        updated_key_ids = set(requested_ids)
-        updated_key_ids.update(verify_keys)
-        updated_key_ids.update(old_verify_keys)
-
-        response_keys.update(verify_keys)
-        response_keys.update(old_verify_keys)
-
-        yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store.store_server_keys_json,
-                    server_name=server_name,
-                    key_id=key_id,
-                    from_server=from_server,
-                    ts_now_ms=time_now_ms,
-                    ts_expires_ms=ts_valid_until_ms,
-                    key_json_bytes=signed_key_json_bytes,
-                )
-                for key_id in updated_key_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
-
-        defer.returnValue(response_keys)
+    @defer.inlineCallbacks
+    def get_server_verify_key_v2_direct(self, server_name, key_ids):
+        """
 
-    def store_keys(self, server_name, from_server, verify_keys):
-        """Store a collection of verify keys for a given server
         Args:
-            server_name(str): The name of the server the keys are for.
-            from_server(str): The server the keys were downloaded from.
-            verify_keys(dict): A mapping of key_id to VerifyKey.
+            server_name (str):
+            key_ids (iterable[str]):
+
         Returns:
-            A deferred that completes when the keys are stored.
+            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+        Raises:
+            KeyLookupError if there was a problem making the lookup
         """
-        # TODO(markjh): Store whether the keys have expired.
-        return logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store.store_server_verify_key,
-                    server_name, server_name, key.time_added, key
+        keys = {}  # type: dict[str, FetchKeyResult]
+
+        for requested_key_id in key_ids:
+            # we may have found this key as a side-effect of asking for another.
+            if requested_key_id in keys:
+                continue
+
+            time_now_ms = self.clock.time_msec()
+            try:
+                response = yield self.client.get_json(
+                    destination=server_name,
+                    path="/_matrix/key/v2/server/"
+                    + urllib.parse.quote(requested_key_id),
+                    ignore_backoff=True,
+
+                    # we only give the remote server 10s to respond. It should be an
+                    # easy request to handle, so if it doesn't reply within 10s, it's
+                    # probably not going to.
+                    #
+                    # Furthermore, when we are acting as a notary server, we cannot
+                    # wait all day for all of the origin servers, as the requesting
+                    # server will otherwise time out before we can respond.
+                    #
+                    # (Note that get_json may make 4 attempts, so this can still take
+                    # almost 45 seconds to fetch the headers, plus up to another 60s to
+                    # read the response).
+                    timeout=10000,
                 )
-                for key_id, key in verify_keys.items()
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+            except (NotRetryingDestination, RequestSendFailed) as e:
+                raise_from(KeyLookupError("Failed to connect to remote server"), e)
+            except HttpResponseException as e:
+                raise_from(KeyLookupError("Remote server returned an error"), e)
+
+            if response["server_name"] != server_name:
+                raise KeyLookupError(
+                    "Expected a response for server %r not %r"
+                    % (server_name, response["server_name"])
+                )
+
+            response_keys = yield self.process_v2_response(
+                from_server=server_name,
+                response_json=response,
+                time_added_ms=time_now_ms,
+            )
+            yield self.store.store_server_verify_keys(
+                server_name,
+                time_now_ms,
+                ((server_name, key_id, key) for key_id, key in response_keys.items()),
+            )
+            keys.update(response_keys)
+
+        defer.returnValue(keys)
 
 
 @defer.inlineCallbacks
@@ -707,48 +861,29 @@ def _handle_key_deferred(verify_request):
         SynapseError if there was a problem performing the verification
     """
     server_name = verify_request.server_name
-    try:
-        with PreserveLoggingContext():
-            _, key_id, verify_key = yield verify_request.deferred
-    except KeyLookupError as e:
-        logger.warn(
-            "Failed to download keys for %s: %s %s",
-            server_name, type(e).__name__, str(e),
-        )
-        raise SynapseError(
-            502,
-            "Error downloading keys for %s" % (server_name,),
-            Codes.UNAUTHORIZED,
-        )
-    except Exception as e:
-        logger.exception(
-            "Got Exception when downloading keys for %s: %s %s",
-            server_name, type(e).__name__, str(e),
-        )
-        raise SynapseError(
-            401,
-            "No key for %s with id %s" % (server_name, verify_request.key_ids),
-            Codes.UNAUTHORIZED,
-        )
+    with PreserveLoggingContext():
+        _, key_id, verify_key = yield verify_request.deferred
 
     json_object = verify_request.json_object
 
-    logger.debug("Got key %s %s:%s for server %s, verifying" % (
-        key_id, verify_key.alg, verify_key.version, server_name,
-    ))
+    logger.debug(
+        "Got key %s %s:%s for server %s, verifying"
+        % (key_id, verify_key.alg, verify_key.version, server_name)
+    )
     try:
         verify_signed_json(json_object, server_name, verify_key)
     except SignatureVerifyException as e:
         logger.debug(
             "Error verifying signature for %s:%s:%s with key %s: %s",
-            server_name, verify_key.alg, verify_key.version,
+            server_name,
+            verify_key.alg,
+            verify_key.version,
             encode_verify_key_base64(verify_key),
             str(e),
         )
         raise SynapseError(
             401,
-            "Invalid signature for server %s with key %s:%s: %s" % (
-                server_name, verify_key.alg, verify_key.version, str(e),
-            ),
+            "Invalid signature for server %s with key %s:%s: %s"
+            % (server_name, verify_key.alg, verify_key.version, str(e)),
             Codes.UNAUTHORIZED,
         )