diff options
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r-- | synapse/crypto/keyring.py | 45 |
1 files changed, 26 insertions, 19 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index aa74d4d0cb..e251ab6af3 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,20 +14,20 @@ # limitations under the License. from synapse.crypto.keyclient import fetch_server_key +from synapse.api.errors import SynapseError, Codes +from synapse.util.retryutils import get_retry_limiter +from synapse.util import unwrapFirstError +from synapse.util.async import ObservableDeferred + from twisted.internet import defer -from syutil.crypto.jsonsign import ( + +from signedjson.sign import ( verify_signed_json, signature_ids, sign_json, encode_canonical_json ) -from syutil.crypto.signing_key import ( +from signedjson.key import ( is_signing_algorithm_supported, decode_verify_key_bytes ) -from syutil.base64util import decode_base64, encode_base64 -from synapse.api.errors import SynapseError, Codes - -from synapse.util.retryutils import get_retry_limiter -from synapse.util import unwrapFirstError - -from synapse.util.async import ObservableDeferred +from unpaddedbase64 import decode_base64, encode_base64 from OpenSSL import crypto @@ -162,7 +162,9 @@ class Keyring(object): def remove_deferreds(res, server_name, group_id): server_to_gids[server_name].discard(group_id) if not server_to_gids[server_name]: - server_to_deferred.pop(server_name).callback(None) + d = server_to_deferred.pop(server_name, None) + if d: + d.callback(None) return res for g_id, deferred in deferreds.items(): @@ -200,8 +202,15 @@ class Keyring(object): else: break - for server_name, deferred in server_to_deferred: - self.key_downloads[server_name] = ObservableDeferred(deferred) + for server_name, deferred in server_to_deferred.items(): + d = ObservableDeferred(deferred) + self.key_downloads[server_name] = d + + def rm(r, server_name): + self.key_downloads.pop(server_name, None) + return r + + d.addBoth(rm, server_name) def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): """Takes a dict of KeyGroups and tries to find at least one key for @@ -220,9 +229,8 @@ class Keyring(object): merged_results = {} missing_keys = { - group.server_name: key_id + group.server_name: set(group.key_ids) for group in group_id_to_group.values() - for key_id in group.key_ids } for fn in key_fetch_fns: @@ -279,16 +287,15 @@ class Keyring(object): def get_keys_from_store(self, server_name_and_key_ids): res = yield defer.gatherResults( [ - self.store.get_server_verify_keys(server_name, key_ids) + self.store.get_server_verify_keys( + server_name, key_ids + ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, ).addErrback(unwrapFirstError) - defer.returnValue(dict(zip( - [server_name for server_name, _ in server_name_and_key_ids], - res - ))) + defer.returnValue(dict(res)) @defer.inlineCallbacks def get_keys_from_perspectives(self, server_name_and_key_ids): |