summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/crypto/keyring.py58
-rw-r--r--tests/crypto/test_keyring.py12
2 files changed, 41 insertions, 29 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index bef6498f4b..5660c96023 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -46,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.logcontext import (
     LoggingContext,
     PreserveLoggingContext,
@@ -169,7 +170,12 @@ class Keyring(object):
                     )
                 )
 
-            logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+            logger.debug(
+                "Verifying for %s with key_ids %s, min_validity %i",
+                server_name,
+                key_ids,
+                validity_time,
+            )
 
             # add the key request to the queue, but don't start it off yet.
             verify_request = VerifyKeyRequest(
@@ -744,34 +750,42 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
 
-    @defer.inlineCallbacks
     def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
-        # TODO make this more resilient
-        results = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.get_server_verify_key_v2_direct,
-                        server_name,
-                        server_keys.keys(),
-                    )
-                    for server_name, server_keys in keys_to_fetch.items()
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
 
-        merged = {}
-        for result in results:
-            merged.update(result)
+        results = {}
+
+        @defer.inlineCallbacks
+        def get_key(key_to_fetch_item):
+            server_name, key_ids = key_to_fetch_item
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+                results[server_name] = keys
+            except KeyLookupError as e:
+                logger.warning(
+                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
+                )
+            except Exception:
+                logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        defer.returnValue(
-            {server_name: keys for server_name, keys in merged.items() if keys}
+        return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+            lambda _: results
         )
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
+        """
+
+        Args:
+            server_name (str):
+            key_ids (iterable[str]):
+
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+        Raises:
+            KeyLookupError if there was a problem making the lookup
+        """
         keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
@@ -823,7 +837,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             )
             keys.update(response_keys)
 
-        defer.returnValue({server_name: keys})
+        defer.returnValue(keys)
 
 
 @defer.inlineCallbacks
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 096401938d..4cff7e36c8 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -25,11 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.crypto.keyring import (
-    KeyLookupError,
-    PerspectivesKeyFetcher,
-    ServerKeyFetcher,
-)
+from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
@@ -364,9 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
             bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
         )
 
-        # change the server name: it should cause a rejection
+        # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
-        self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
+
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        self.assertEqual(keys, {})
 
 
 class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):