diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c63f106cf3..194867db03 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -270,59 +270,21 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
+ remaining_requests = set(
+ (rq for rq in verify_requests if not rq.deferred.called)
+ )
+
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
- # dict[str, set(str)]: keys to fetch for each server
- missing_keys = {}
- for verify_request in verify_requests:
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
-
for f in self._key_fetchers:
- results = yield f.get_keys(missing_keys.items())
-
- # We now need to figure out which verify requests we have keys
- # for and which we don't
- missing_keys = {}
- requests_missing_keys = []
- for verify_request in verify_requests:
- if verify_request.deferred.called:
- # We've already called this deferred, which probably
- # means that we've already found a key for it.
- continue
-
- server_name = verify_request.server_name
-
- # see if any of the keys we got this time are sufficient to
- # complete this VerifyKeyRequest.
- result_keys = results.get(server_name, {})
- for key_id in verify_request.key_ids:
- fetch_key_result = result_keys.get(key_id)
- if fetch_key_result:
- with PreserveLoggingContext():
- verify_request.deferred.callback(
- (
- server_name,
- key_id,
- fetch_key_result.verify_key,
- )
- )
- break
- else:
- # The else block is only reached if the loop above
- # doesn't break.
- missing_keys.setdefault(server_name, set()).update(
- verify_request.key_ids
- )
- requests_missing_keys.append(verify_request)
-
- if not missing_keys:
- break
+ if not remaining_requests:
+ return
+ yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+ # look for any requests which weren't satisfied
with PreserveLoggingContext():
- for verify_request in requests_missing_keys:
+ for verify_request in remaining_requests:
verify_request.deferred.errback(
SynapseError(
401,
@@ -333,13 +295,56 @@ class Keyring(object):
)
def on_err(err):
+ # we don't really expect to get here, because any errors should already
+ # have been caught and logged. But if we do, let's log the error and make
+ # sure that all of the deferreds are resolved.
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
- for verify_request in verify_requests:
+ for verify_request in remaining_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
run_in_background(do_iterations).addErrback(on_err)
+ @defer.inlineCallbacks
+ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ """Use a key fetcher to attempt to satisfy some key requests
+
+ Args:
+ fetcher (KeyFetcher): fetcher to use to fetch the keys
+ remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
+ Any successfully-completed requests will be reomved from the list.
+ """
+ # dict[str, set(str)]: keys to fetch for each server
+ missing_keys = {}
+ for verify_request in remaining_requests:
+ # any completed requests should already have been removed
+ assert not verify_request.deferred.called
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
+ )
+
+ results = yield fetcher.get_keys(missing_keys.items())
+
+ completed = list()
+ for verify_request in remaining_requests:
+ server_name = verify_request.server_name
+
+ # see if any of the keys we got this time are sufficient to
+ # complete this VerifyKeyRequest.
+ result_keys = results.get(server_name, {})
+ for key_id in verify_request.key_ids:
+ key = result_keys.get(key_id)
+ if key:
+ with PreserveLoggingContext():
+ verify_request.deferred.callback(
+ (server_name, key_id, key.verify_key)
+ )
+ completed.append(verify_request)
+ break
+
+ remaining_requests.difference_update(completed)
+
class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
|