diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index bef6498f4b..5660c96023 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -46,6 +46,7 @@ from synapse.api.errors import (
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -169,7 +170,12 @@ class Keyring(object):
)
)
- logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+ logger.debug(
+ "Verifying for %s with key_ids %s, min_validity %i",
+ server_name,
+ key_ids,
+ validity_time,
+ )
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
@@ -744,34 +750,42 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
- @defer.inlineCallbacks
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
- # TODO make this more resilient
- results = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct,
- server_name,
- server_keys.keys(),
- )
- for server_name, server_keys in keys_to_fetch.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
- merged = {}
- for result in results:
- merged.update(result)
+ results = {}
+
+ @defer.inlineCallbacks
+ def get_key(key_to_fetch_item):
+ server_name, key_ids = key_to_fetch_item
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ results[server_name] = keys
+ except KeyLookupError as e:
+ logger.warning(
+ "Error looking up keys %s from %s: %s", key_ids, server_name, e
+ )
+ except Exception:
+ logger.exception("Error getting keys %s from %s", key_ids, server_name)
- defer.returnValue(
- {server_name: keys for server_name, keys in merged.items() if keys}
+ return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+ lambda _: results
)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ """
+
+ Args:
+ server_name (str):
+ key_ids (iterable[str]):
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+ Raises:
+ KeyLookupError if there was a problem making the lookup
+ """
keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
@@ -823,7 +837,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue({server_name: keys})
+ defer.returnValue(keys)
@defer.inlineCallbacks
|