diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 873c9b40fa..aa74d4d0cb 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
+from synapse.util.async import ObservableDeferred
+
from OpenSSL import crypto
from collections import namedtuple
@@ -88,6 +90,8 @@ class Keyring(object):
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
+ else:
+ deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
@@ -133,10 +137,41 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
- deferreds.update(self.get_server_verify_keys(
- group_id_to_group
- ))
+ server_to_deferred = {
+ server_name: defer.Deferred()
+ for server_name, _ in server_and_json
+ }
+
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
+
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
+
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ server_to_deferred.pop(server_name).callback(None)
+ return res
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
+
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
@@ -145,7 +180,30 @@ class Keyring(object):
for g_id in group_ids
]
- def get_server_verify_keys(self, group_id_to_group):
+ @defer.inlineCallbacks
+ def wait_for_previous_lookups(self, server_names, server_to_deferred):
+ """Waits for any previous key lookups for the given servers to finish.
+
+ Args:
+ server_names (list): list of server_names we want to lookup
+ server_to_deferred (dict): server_name to deferred which gets
+ resolved once we've finished looking up keys for that server
+ """
+ while True:
+ wait_on = [
+ self.key_downloads[server_name]
+ for server_name in server_names
+ if server_name in self.key_downloads
+ ]
+ if wait_on:
+ yield defer.DeferredList(wait_on)
+ else:
+ break
+
+ for server_name, deferred in server_to_deferred:
+ self.key_downloads[server_name] = ObservableDeferred(deferred)
+
+ def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -157,11 +215,6 @@ class Keyring(object):
self.get_keys_from_server, # Then try directly
)
- group_deferreds = {
- group_id: defer.Deferred()
- for group_id in group_id_to_group
- }
-
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
@@ -182,7 +235,7 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
- group_deferreds.pop(group.group_id).callback((
+ group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
@@ -205,7 +258,7 @@ class Keyring(object):
}
for group in missing_groups.values():
- group_deferreds.pop(group.group_id).errback(SynapseError(
+ group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
@@ -214,13 +267,13 @@ class Keyring(object):
))
def on_err(err):
- for deferred in group_deferreds.values():
- deferred.errback(err)
- group_deferreds.clear()
+ for deferred in group_id_to_deferred.values():
+ if not deferred.called:
+ deferred.errback(err)
do_iterations().addErrback(on_err)
- return group_deferreds
+ return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
|