summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py64
1 files changed, 38 insertions, 26 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8709394b97..aff69c5f83 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
 
 from synapse.util.retryutils import get_retry_limiter
 
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
 
 from OpenSSL import crypto
 
@@ -111,6 +111,10 @@ class Keyring(object):
 
         if download is None:
             download = self._get_server_verify_key_impl(server_name, key_ids)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
             self.key_downloads[server_name] = download
 
             @download.addBoth
@@ -118,30 +122,31 @@ class Keyring(object):
                 del self.key_downloads[server_name]
                 return ret
 
-        r = yield create_observer(download)
+        r = yield download.observe()
         defer.returnValue(r)
 
     @defer.inlineCallbacks
     def _get_server_verify_key_impl(self, server_name, key_ids):
         keys = None
 
-        perspective_results = []
-        for perspective_name, perspective_keys in self.perspective_servers.items():
-            @defer.inlineCallbacks
-            def get_key():
-                try:
-                    result = yield self.get_server_verify_key_v2_indirect(
-                        server_name, key_ids, perspective_name, perspective_keys
-                    )
-                    defer.returnValue(result)
-                except:
-                    logging.info(
-                        "Unable to getting key %r for %r from %r",
-                        key_ids, server_name, perspective_name,
-                    )
-            perspective_results.append(get_key())
+        @defer.inlineCallbacks
+        def get_key(perspective_name, perspective_keys):
+            try:
+                result = yield self.get_server_verify_key_v2_indirect(
+                    server_name, key_ids, perspective_name, perspective_keys
+                )
+                defer.returnValue(result)
+            except Exception as e:
+                logging.info(
+                    "Unable to getting key %r for %r from %r: %s %s",
+                    key_ids, server_name, perspective_name,
+                    type(e).__name__, str(e.message),
+                )
 
-        perspective_results = yield defer.gatherResults(perspective_results)
+        perspective_results = yield defer.gatherResults([
+            get_key(p_name, p_keys)
+            for p_name, p_keys in self.perspective_servers.items()
+        ])
 
         for results in perspective_results:
             if results is not None:
@@ -154,17 +159,22 @@ class Keyring(object):
         )
 
         with limiter:
-            if keys is None:
+            if not keys:
                 try:
                     keys = yield self.get_server_verify_key_v2_direct(
                         server_name, key_ids
                     )
-                except:
-                    pass
+                except Exception as e:
+                    logging.info(
+                        "Unable to getting key %r for %r directly: %s %s",
+                        key_ids, server_name,
+                        type(e).__name__, str(e.message),
+                    )
 
-            keys = yield self.get_server_verify_key_v1_direct(
-                server_name, key_ids
-            )
+            if not keys:
+                keys = yield self.get_server_verify_key_v1_direct(
+                    server_name, key_ids
+                )
 
         for key_id in key_ids:
             if key_id in keys:
@@ -184,7 +194,7 @@ class Keyring(object):
             # TODO(mark): Set the minimum_valid_until_ts to that needed by
             # the events being validated or the current time if validating
             # an incoming request.
-            responses = yield self.client.post_json(
+            query_response = yield self.client.post_json(
                 destination=perspective_name,
                 path=b"/_matrix/key/v2/query",
                 data={
@@ -200,6 +210,8 @@ class Keyring(object):
 
         keys = {}
 
+        responses = query_response["server_keys"]
+
         for response in responses:
             if (u"signatures" not in response
                     or perspective_name not in response[u"signatures"]):
@@ -323,7 +335,7 @@ class Keyring(object):
                 verify_key.time_added = time_now_ms
                 old_verify_keys[key_id] = verify_key
 
-        for key_id in response_json["signatures"][server_name]:
+        for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
                 raise ValueError(
                     "Key response must include verification keys for all"