diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aa74d4d0cb..e251ab6af3 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,20 +14,20 @@
# limitations under the License.
from synapse.crypto.keyclient import fetch_server_key
+from synapse.api.errors import SynapseError, Codes
+from synapse.util.retryutils import get_retry_limiter
+from synapse.util import unwrapFirstError
+from synapse.util.async import ObservableDeferred
+
from twisted.internet import defer
-from syutil.crypto.jsonsign import (
+
+from signedjson.sign import (
verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
-from syutil.crypto.signing_key import (
+from signedjson.key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
-from syutil.base64util import decode_base64, encode_base64
-from synapse.api.errors import SynapseError, Codes
-
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-
-from synapse.util.async import ObservableDeferred
+from unpaddedbase64 import decode_base64, encode_base64
from OpenSSL import crypto
@@ -162,7 +162,9 @@ class Keyring(object):
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
- server_to_deferred.pop(server_name).callback(None)
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
return res
for g_id, deferred in deferreds.items():
@@ -200,8 +202,15 @@ class Keyring(object):
else:
break
- for server_name, deferred in server_to_deferred:
- self.key_downloads[server_name] = ObservableDeferred(deferred)
+ for server_name, deferred in server_to_deferred.items():
+ d = ObservableDeferred(deferred)
+ self.key_downloads[server_name] = d
+
+ def rm(r, server_name):
+ self.key_downloads.pop(server_name, None)
+ return r
+
+ d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
@@ -220,9 +229,8 @@ class Keyring(object):
merged_results = {}
missing_keys = {
- group.server_name: key_id
+ group.server_name: set(group.key_ids)
for group in group_id_to_group.values()
- for key_id in group.key_ids
}
for fn in key_fetch_fns:
@@ -279,16 +287,15 @@ class Keyring(object):
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
- self.store.get_server_verify_keys(server_name, key_ids)
+ self.store.get_server_verify_keys(
+ server_name, key_ids
+ ).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
- defer.returnValue(dict(zip(
- [server_name for server_name, _ in server_name_and_key_ids],
- res
- )))
+ defer.returnValue(dict(res))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
|