summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyring.py101
1 files changed, 53 insertions, 48 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c63f106cf3..194867db03 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -270,59 +270,21 @@ class Keyring(object):
             verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
+        remaining_requests = set(
+            (rq for rq in verify_requests if not rq.deferred.called)
+        )
+
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
-                # dict[str, set(str)]: keys to fetch for each server
-                missing_keys = {}
-                for verify_request in verify_requests:
-                    missing_keys.setdefault(verify_request.server_name, set()).update(
-                        verify_request.key_ids
-                    )
-
                 for f in self._key_fetchers:
-                    results = yield f.get_keys(missing_keys.items())
-
-                    # We now need to figure out which verify requests we have keys
-                    # for and which we don't
-                    missing_keys = {}
-                    requests_missing_keys = []
-                    for verify_request in verify_requests:
-                        if verify_request.deferred.called:
-                            # We've already called this deferred, which probably
-                            # means that we've already found a key for it.
-                            continue
-
-                        server_name = verify_request.server_name
-
-                        # see if any of the keys we got this time are sufficient to
-                        # complete this VerifyKeyRequest.
-                        result_keys = results.get(server_name, {})
-                        for key_id in verify_request.key_ids:
-                            fetch_key_result = result_keys.get(key_id)
-                            if fetch_key_result:
-                                with PreserveLoggingContext():
-                                    verify_request.deferred.callback(
-                                        (
-                                            server_name,
-                                            key_id,
-                                            fetch_key_result.verify_key,
-                                        )
-                                    )
-                                break
-                        else:
-                            # The else block is only reached if the loop above
-                            # doesn't break.
-                            missing_keys.setdefault(server_name, set()).update(
-                                verify_request.key_ids
-                            )
-                            requests_missing_keys.append(verify_request)
-
-                    if not missing_keys:
-                        break
+                    if not remaining_requests:
+                        return
+                    yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
 
+                # look for any requests which weren't satisfied
                 with PreserveLoggingContext():
-                    for verify_request in requests_missing_keys:
+                    for verify_request in remaining_requests:
                         verify_request.deferred.errback(
                             SynapseError(
                                 401,
@@ -333,13 +295,56 @@ class Keyring(object):
                         )
 
         def on_err(err):
+            # we don't really expect to get here, because any errors should already
+            # have been caught and logged. But if we do, let's log the error and make
+            # sure that all of the deferreds are resolved.
+            logger.error("Unexpected error in _get_server_verify_keys: %s", err)
             with PreserveLoggingContext():
-                for verify_request in verify_requests:
+                for verify_request in remaining_requests:
                     if not verify_request.deferred.called:
                         verify_request.deferred.errback(err)
 
         run_in_background(do_iterations).addErrback(on_err)
 
+    @defer.inlineCallbacks
+    def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+        """Use a key fetcher to attempt to satisfy some key requests
+
+        Args:
+            fetcher (KeyFetcher): fetcher to use to fetch the keys
+            remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
+                Any successfully-completed requests will be reomved from the list.
+        """
+        # dict[str, set(str)]: keys to fetch for each server
+        missing_keys = {}
+        for verify_request in remaining_requests:
+            # any completed requests should already have been removed
+            assert not verify_request.deferred.called
+            missing_keys.setdefault(verify_request.server_name, set()).update(
+                verify_request.key_ids
+            )
+
+        results = yield fetcher.get_keys(missing_keys.items())
+
+        completed = list()
+        for verify_request in remaining_requests:
+            server_name = verify_request.server_name
+
+            # see if any of the keys we got this time are sufficient to
+            # complete this VerifyKeyRequest.
+            result_keys = results.get(server_name, {})
+            for key_id in verify_request.key_ids:
+                key = result_keys.get(key_id)
+                if key:
+                    with PreserveLoggingContext():
+                        verify_request.deferred.callback(
+                            (server_name, key_id, key.verify_key)
+                        )
+                    completed.append(verify_request)
+                    break
+
+        remaining_requests.difference_update(completed)
+
 
 class KeyFetcher(object):
     def get_keys(self, server_name_and_key_ids):