summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-06-10 20:28:08 +0100
committerRichard van der Hoff <richard@matrix.org>2019-06-10 20:28:08 +0100
commit69a43d9974824893605252bf32d08484fdaa28fb (patch)
tree180a6d2e659269ebf6e61b424051497cbab22848 /synapse/crypto/keyring.py
parentSAML2 Improvements and redirect stuff (diff)
parentMerge branch 'release-v1.0.0' of github.com:matrix-org/synapse into develop (diff)
downloadsynapse-69a43d9974824893605252bf32d08484fdaa28fb.tar.xz
Merge remote-tracking branch 'origin/develop' into rav/saml2_client
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py410
1 files changed, 252 insertions, 158 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f4918d1bc6..6f603f1961 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from collections import defaultdict
 
 import six
 from six import raise_from
@@ -45,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.logcontext import (
     LoggingContext,
     PreserveLoggingContext,
@@ -58,9 +60,9 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, cmp=False)
-class VerifyKeyRequest(object):
+class VerifyJsonRequest(object):
     """
-    A request for a verify key to verify a JSON object.
+    A request to verify a JSON object.
 
     Attributes:
         server_name(str): The name of the server to verify against.
@@ -70,7 +72,10 @@ class VerifyKeyRequest(object):
 
         json_object(dict): The JSON object to verify.
 
-        deferred(Deferred[str, str, nacl.signing.VerifyKey]):
+        minimum_valid_until_ts (int): time at which we require the signing key to
+            be valid. (0 implies we don't care)
+
+        key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
             logcontext.
@@ -80,9 +85,14 @@ class VerifyKeyRequest(object):
     """
 
     server_name = attr.ib()
-    key_ids = attr.ib()
     json_object = attr.ib()
-    deferred = attr.ib()
+    minimum_valid_until_ts = attr.ib()
+    request_name = attr.ib()
+    key_ids = attr.ib(init=False)
+    key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+
+    def __attrs_post_init__(self):
+        self.key_ids = signature_ids(self.json_object, self.server_name)
 
 
 class KeyLookupError(ValueError):
@@ -90,14 +100,16 @@ class KeyLookupError(ValueError):
 
 
 class Keyring(object):
-    def __init__(self, hs):
+    def __init__(self, hs, key_fetchers=None):
         self.clock = hs.get_clock()
 
-        self._key_fetchers = (
-            StoreKeyFetcher(hs),
-            PerspectivesKeyFetcher(hs),
-            ServerKeyFetcher(hs),
-        )
+        if key_fetchers is None:
+            key_fetchers = (
+                StoreKeyFetcher(hs),
+                PerspectivesKeyFetcher(hs),
+                ServerKeyFetcher(hs),
+            )
+        self._key_fetchers = key_fetchers
 
         # map from server name to Deferred. Has an entry for each server with
         # an ongoing key download; the Deferred completes once the download
@@ -106,51 +118,99 @@ class Keyring(object):
         # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
-    def verify_json_for_server(self, server_name, json_object):
-        return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server([(server_name, json_object)])[0]
-        )
+    def verify_json_for_server(
+        self, server_name, json_object, validity_time, request_name
+    ):
+        """Verify that a JSON object has been signed by a given server
+
+        Args:
+            server_name (str): name of the server which must have signed this object
+
+            json_object (dict): object to be checked
+
+            validity_time (int): timestamp at which we require the signing key to
+                be valid. (0 implies we don't care)
+
+            request_name (str): an identifier for this json object (eg, an event id)
+                for logging.
+
+        Returns:
+            Deferred[None]: completes if the the object was correctly signed, otherwise
+                errbacks with an error
+        """
+        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+        requests = (req,)
+        return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0])
 
     def verify_json_objects_for_server(self, server_and_json):
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
-            server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int, str]):
+                Iterable of (server_name, json_object, validity_time, request_name)
+                tuples.
+
+                validity_time is a timestamp at which the signing key must be
+                valid.
+
+                request_name is an identifier for this json object (eg, an event id)
+                for logging.
 
         Returns:
-            List<Deferred>: for each input pair, a deferred indicating success
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        # a list of VerifyKeyRequests
-        verify_requests = []
+        return self._verify_objects(
+            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+            for server_name, json_object, validity_time, request_name in server_and_json
+        )
+
+    def _verify_objects(self, verify_requests):
+        """Does the work of verify_json_[objects_]for_server
+
+
+        Args:
+            verify_requests (iterable[VerifyJsonRequest]):
+                Iterable of verification requests.
+
+        Returns:
+            List<Deferred[None]>: for each input item, a deferred indicating success
+                or failure to verify each json object's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
+        """
+        # a list of VerifyJsonRequests which are awaiting a key lookup
+        key_lookups = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(server_name, json_object):
+        def process(verify_request):
             """Process an entry in the request list
 
-            Given a (server_name, json_object) pair from the request list,
-            adds a key request to verify_requests, and returns a deferred which will
-            complete or fail (in the sentinel context) when verification completes.
+            Adds a key request to key_lookups, and returns a deferred which
+            will complete or fail (in the sentinel context) when verification completes.
             """
-            key_ids = signature_ids(json_object, server_name)
-
-            if not key_ids:
+            if not verify_request.key_ids:
                 return defer.fail(
                     SynapseError(
-                        400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
+                        400,
+                        "Not signed by %s" % (verify_request.server_name,),
+                        Codes.UNAUTHORIZED,
                     )
                 )
 
-            logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+            logger.debug(
+                "Verifying %s for %s with key_ids %s, min_validity %i",
+                verify_request.request_name,
+                verify_request.server_name,
+                verify_request.key_ids,
+                verify_request.minimum_valid_until_ts,
+            )
 
             # add the key request to the queue, but don't start it off yet.
-            verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, defer.Deferred()
-            )
-            verify_requests.append(verify_request)
+            key_lookups.append(verify_request)
 
             # now run _handle_key_deferred, which will wait for the key request
             # to complete and then do the verification.
@@ -159,13 +219,10 @@ class Keyring(object):
             # wrap it with preserve_fn (aka run_in_background)
             return handle(verify_request)
 
-        results = [
-            process(server_name, json_object)
-            for server_name, json_object in server_and_json
-        ]
+        results = [process(r) for r in verify_requests]
 
-        if verify_requests:
-            run_in_background(self._start_key_lookups, verify_requests)
+        if key_lookups:
+            run_in_background(self._start_key_lookups, key_lookups)
 
         return results
 
@@ -173,10 +230,10 @@ class Keyring(object):
     def _start_key_lookups(self, verify_requests):
         """Sets off the key fetches for each verify request
 
-        Once each fetch completes, verify_request.deferred will be resolved.
+        Once each fetch completes, verify_request.key_ready will be resolved.
 
         Args:
-            verify_requests (List[VerifyKeyRequest]):
+            verify_requests (List[VerifyJsonRequest]):
         """
 
         try:
@@ -219,7 +276,7 @@ class Keyring(object):
                 return res
 
             for verify_request in verify_requests:
-                verify_request.deferred.addBoth(remove_deferreds, verify_request)
+                verify_request.key_ready.addBoth(remove_deferreds, verify_request)
         except Exception:
             logger.exception("Error starting key lookups")
 
@@ -272,16 +329,16 @@ class Keyring(object):
     def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
 
-        For each verify_request, verify_request.deferred is called back with
+        For each verify_request, verify_request.key_ready is called back with
         params (server_name, key_id, VerifyKey) if a key is found, or errbacked
         with a SynapseError if none of the keys are found.
 
         Args:
-            verify_requests (list[VerifyKeyRequest]): list of verify requests
+            verify_requests (list[VerifyJsonRequest]): list of verify requests
         """
 
         remaining_requests = set(
-            (rq for rq in verify_requests if not rq.deferred.called)
+            (rq for rq in verify_requests if not rq.key_ready.called)
         )
 
         @defer.inlineCallbacks
@@ -295,11 +352,15 @@ class Keyring(object):
                 # look for any requests which weren't satisfied
                 with PreserveLoggingContext():
                     for verify_request in remaining_requests:
-                        verify_request.deferred.errback(
+                        verify_request.key_ready.errback(
                             SynapseError(
                                 401,
-                                "No key for %s with id %s"
-                                % (verify_request.server_name, verify_request.key_ids),
+                                "No key for %s with ids in %s (min_validity %i)"
+                                % (
+                                    verify_request.server_name,
+                                    verify_request.key_ids,
+                                    verify_request.minimum_valid_until_ts,
+                                ),
                                 Codes.UNAUTHORIZED,
                             )
                         )
@@ -311,8 +372,8 @@ class Keyring(object):
             logger.error("Unexpected error in _get_server_verify_keys: %s", err)
             with PreserveLoggingContext():
                 for verify_request in remaining_requests:
-                    if not verify_request.deferred.called:
-                        verify_request.deferred.errback(err)
+                    if not verify_request.key_ready.called:
+                        verify_request.key_ready.errback(err)
 
         run_in_background(do_iterations).addErrback(on_err)
 
@@ -322,47 +383,66 @@ class Keyring(object):
 
         Args:
             fetcher (KeyFetcher): fetcher to use to fetch the keys
-            remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
-                Any successfully-completed requests will be reomved from the list.
+            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+                Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, set(str)]: keys to fetch for each server
-        missing_keys = {}
+        # dict[str, dict[str, int]]: keys to fetch.
+        # server_name -> key_id -> min_valid_ts
+        missing_keys = defaultdict(dict)
+
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
-            assert not verify_request.deferred.called
-            missing_keys.setdefault(verify_request.server_name, set()).update(
-                verify_request.key_ids
-            )
+            assert not verify_request.key_ready.called
+            keys_for_server = missing_keys[verify_request.server_name]
+
+            for key_id in verify_request.key_ids:
+                # If we have several requests for the same key, then we only need to
+                # request that key once, but we should do so with the greatest
+                # min_valid_until_ts of the requests, so that we can satisfy all of
+                # the requests.
+                keys_for_server[key_id] = max(
+                    keys_for_server.get(key_id, -1),
+                    verify_request.minimum_valid_until_ts,
+                )
 
-        results = yield fetcher.get_keys(missing_keys.items())
+        results = yield fetcher.get_keys(missing_keys)
 
         completed = list()
         for verify_request in remaining_requests:
             server_name = verify_request.server_name
 
             # see if any of the keys we got this time are sufficient to
-            # complete this VerifyKeyRequest.
+            # complete this VerifyJsonRequest.
             result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
-                key = result_keys.get(key_id)
-                if key:
-                    with PreserveLoggingContext():
-                        verify_request.deferred.callback(
-                            (server_name, key_id, key.verify_key)
-                        )
-                    completed.append(verify_request)
-                    break
+                fetch_key_result = result_keys.get(key_id)
+                if not fetch_key_result:
+                    # we didn't get a result for this key
+                    continue
+
+                if (
+                    fetch_key_result.valid_until_ts
+                    < verify_request.minimum_valid_until_ts
+                ):
+                    # key was not valid at this point
+                    continue
+
+                with PreserveLoggingContext():
+                    verify_request.key_ready.callback(
+                        (server_name, key_id, fetch_key_result.verify_key)
+                    )
+                completed.append(verify_request)
+                break
 
         remaining_requests.difference_update(completed)
 
 
 class KeyFetcher(object):
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
-                Note that the iterables may be iterated more than once.
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +458,15 @@ class StoreKeyFetcher(KeyFetcher):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
+
         keys_to_fetch = (
             (server_name, key_id)
-            for server_name, key_ids in server_name_and_key_ids
-            for key_id in key_ids
+            for server_name, keys_for_server in keys_to_fetch.items()
+            for key_id in keys_for_server.keys()
         )
+
         res = yield self.store.get_server_verify_keys(keys_to_fetch)
         keys = {}
         for (server_name, key_id), key in res.items():
@@ -398,9 +480,7 @@ class BaseV2KeyFetcher(object):
         self.config = hs.get_config()
 
     @defer.inlineCallbacks
-    def process_v2_response(
-        self, from_server, response_json, time_added_ms, requested_ids=[]
-    ):
+    def process_v2_response(self, from_server, response_json, time_added_ms):
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -422,10 +502,6 @@ class BaseV2KeyFetcher(object):
 
             time_added_ms (int): the timestamp to record in server_keys_json
 
-            requested_ids (iterable[str]): a list of the key IDs that were requested.
-                We will store the json for these key ids as well as any that are
-                actually in the response
-
         Returns:
             Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
         """
@@ -481,11 +557,6 @@ class BaseV2KeyFetcher(object):
 
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
 
-        # for reasons I don't quite understand, we store this json for the key ids we
-        # requested, as well as those we got.
-        updated_key_ids = set(requested_ids)
-        updated_key_ids.update(verify_keys)
-
         yield logcontext.make_deferred_yieldable(
             defer.gatherResults(
                 [
@@ -498,7 +569,7 @@ class BaseV2KeyFetcher(object):
                         ts_expires_ms=ts_valid_until_ms,
                         key_json_bytes=signed_key_json_bytes,
                     )
-                    for key_id in updated_key_ids
+                    for key_id in verify_keys
                 ],
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
@@ -514,25 +585,27 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         super(PerspectivesKeyFetcher, self).__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
-        self.perspective_servers = self.config.perspectives
+        self.key_servers = self.config.key_servers
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
 
         @defer.inlineCallbacks
-        def get_key(perspective_name, perspective_keys):
+        def get_key(key_server):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name_and_key_ids, perspective_name, perspective_keys
+                    keys_to_fetch, key_server
                 )
                 defer.returnValue(result)
             except KeyLookupError as e:
-                logger.warning("Key lookup failed from %r: %s", perspective_name, e)
+                logger.warning(
+                    "Key lookup failed from %r: %s", key_server.server_name, e
+                )
             except Exception as e:
                 logger.exception(
                     "Unable to get key from %r: %s %s",
-                    perspective_name,
+                    key_server.server_name,
                     type(e).__name__,
                     str(e),
                 )
@@ -542,8 +615,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         results = yield logcontext.make_deferred_yieldable(
             defer.gatherResults(
                 [
-                    run_in_background(get_key, p_name, p_keys)
-                    for p_name, p_keys in self.perspective_servers.items()
+                    run_in_background(get_key, server)
+                    for server in self.key_servers
                 ],
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
@@ -558,15 +631,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_indirect(
-        self, server_names_and_key_ids, perspective_name, perspective_keys
+        self, keys_to_fetch, key_server
     ):
         """
         Args:
-            server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
-            perspective_name (str): name of the notary server to query for the keys
-            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
-                notary server
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
+
+            key_server (synapse.config.key.TrustedKeyServer): notary server to query for
+                the keys
 
         Returns:
             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
@@ -576,14 +649,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             KeyLookupError if there was an error processing the entire response from
                 the server
         """
+        perspective_name = key_server.server_name
         logger.info(
             "Requesting keys %s from notary server %s",
-            server_names_and_key_ids,
+            keys_to_fetch.items(),
             perspective_name,
         )
-        # TODO(mark): Set the minimum_valid_until_ts to that needed by
-        # the events being validated or the current time if validating
-        # an incoming request.
+
         try:
             query_response = yield self.client.post_json(
                 destination=perspective_name,
@@ -591,12 +663,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 data={
                     u"server_keys": {
                         server_name: {
-                            key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
+                            key_id: {u"minimum_valid_until_ts": min_valid_ts}
+                            for key_id, min_valid_ts in server_keys.items()
                         }
-                        for server_name, key_ids in server_names_and_key_ids
+                        for server_name, server_keys in keys_to_fetch.items()
                     }
                 },
-                long_retries=True,
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -618,11 +690,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 )
 
             try:
-                processed_response = yield self._process_perspectives_response(
-                    perspective_name,
-                    perspective_keys,
+                self._validate_perspectives_response(
+                    key_server,
                     response,
-                    time_added_ms=time_now_ms,
+                )
+
+                processed_response = yield self.process_v2_response(
+                    perspective_name, response, time_added_ms=time_now_ms
                 )
             except KeyLookupError as e:
                 logger.warning(
@@ -646,28 +720,24 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
         defer.returnValue(keys)
 
-    def _process_perspectives_response(
-        self, perspective_name, perspective_keys, response, time_added_ms
+    def _validate_perspectives_response(
+        self, key_server, response,
     ):
-        """Parse a 'Server Keys' structure from the result of a /key/query request
-
-        Checks that the entry is correctly signed by the perspectives server, and then
-        passes over to process_v2_response
+        """Optionally check the signature on the result of a /key/query request
 
         Args:
-            perspective_name (str): the name of the notary server that produced this
-                result
-
-            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
-                notary server
+            key_server (synapse.config.key.TrustedKeyServer): the notary server that
+                produced this result
 
             response (dict): the json-decoded Server Keys response object
+        """
+        perspective_name = key_server.server_name
+        perspective_keys = key_server.verify_keys
 
-            time_added_ms (int): the timestamp to record in server_keys_json
+        if perspective_keys is None:
+            # signature checking is disabled on this server
+            return
 
-        Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
-        """
         if (
             u"signatures" not in response
             or perspective_name not in response[u"signatures"]
@@ -689,10 +759,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 )
             )
 
-        return self.process_v2_response(
-            perspective_name, response, time_added_ms=time_added_ms
-        )
-
 
 class ServerKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the origin servers"""
@@ -702,34 +768,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
 
-    @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
-        """see KeyFetcher.get_keys"""
-        results = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.get_server_verify_key_v2_direct, server_name, key_ids
-                    )
-                    for server_name, key_ids in server_name_and_key_ids
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
+    def get_keys(self, keys_to_fetch):
+        """
+        Args:
+            keys_to_fetch (dict[str, iterable[str]]):
+                the keys to be fetched. server_name -> key_ids
 
-        merged = {}
-        for result in results:
-            merged.update(result)
+        Returns:
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
+        """
+
+        results = {}
+
+        @defer.inlineCallbacks
+        def get_key(key_to_fetch_item):
+            server_name, key_ids = key_to_fetch_item
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+                results[server_name] = keys
+            except KeyLookupError as e:
+                logger.warning(
+                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
+                )
+            except Exception:
+                logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        defer.returnValue(
-            {server_name: keys for server_name, keys in merged.items() if keys}
+        return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+            lambda _: results
         )
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
+        """
+
+        Args:
+            server_name (str):
+            key_ids (iterable[str]):
+
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+        Raises:
+            KeyLookupError if there was a problem making the lookup
+        """
         keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
+            # we may have found this key as a side-effect of asking for another.
             if requested_key_id in keys:
                 continue
 
@@ -740,6 +826,19 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
                     path="/_matrix/key/v2/server/"
                     + urllib.parse.quote(requested_key_id),
                     ignore_backoff=True,
+
+                    # we only give the remote server 10s to respond. It should be an
+                    # easy request to handle, so if it doesn't reply within 10s, it's
+                    # probably not going to.
+                    #
+                    # Furthermore, when we are acting as a notary server, we cannot
+                    # wait all day for all of the origin servers, as the requesting
+                    # server will otherwise time out before we can respond.
+                    #
+                    # (Note that get_json may make 4 attempts, so this can still take
+                    # almost 45 seconds to fetch the headers, plus up to another 60s to
+                    # read the response).
+                    timeout=10000,
                 )
             except (NotRetryingDestination, RequestSendFailed) as e:
                 raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -754,7 +853,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
-                requested_ids=[requested_key_id],
                 response_json=response,
                 time_added_ms=time_now_ms,
             )
@@ -765,7 +863,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             )
             keys.update(response_keys)
 
-        defer.returnValue({server_name: keys})
+        defer.returnValue(keys)
 
 
 @defer.inlineCallbacks
@@ -773,7 +871,7 @@ def _handle_key_deferred(verify_request):
     """Waits for the key to become available, and then performs a verification
 
     Args:
-        verify_request (VerifyKeyRequest):
+        verify_request (VerifyJsonRequest):
 
     Returns:
         Deferred[None]
@@ -783,14 +881,10 @@ def _handle_key_deferred(verify_request):
     """
     server_name = verify_request.server_name
     with PreserveLoggingContext():
-        _, key_id, verify_key = yield verify_request.deferred
+        _, key_id, verify_key = yield verify_request.key_ready
 
     json_object = verify_request.json_object
 
-    logger.debug(
-        "Got key %s %s:%s for server %s, verifying"
-        % (key_id, verify_key.alg, verify_key.version, server_name)
-    )
     try:
         verify_signed_json(json_object, server_name, verify_key)
     except SignatureVerifyException as e: