diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f687d41ccb..5012c10ee8 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -44,7 +44,21 @@ import logging
logger = logging.getLogger(__name__)
-KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+VerifyKeyRequest = namedtuple("VerifyRequest", (
+ "server_name", "key_ids", "json_object", "deferred"
+))
+"""
+A request for a verify key to verify a JSON object.
+
+Attributes:
+ server_name(str): The name of the server to verify against.
+ key_ids(set(str)): The set of key_ids to that could be used to verify the
+ JSON object
+ json_object(dict): The JSON object to verify.
+ deferred(twisted.internet.defer.Deferred):
+ A deferred (server_name, key_id, verify_key) tuple that resolves when
+ a verify key has been fetched
+"""
class Keyring(object):
@@ -74,39 +88,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
- group_id_to_json = {}
- group_id_to_group = {}
- group_ids = []
-
- next_group_id = 0
- deferreds = {}
+ verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
- group_id = next_group_id
- next_group_id += 1
- group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
- deferreds[group_id] = defer.fail(SynapseError(
+ deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
- deferreds[group_id] = defer.Deferred()
+ deferred = defer.Deferred()
- group = KeyGroup(server_name, group_id, key_ids)
+ verify_request = VerifyKeyRequest(
+ server_name, key_ids, json_object, deferred
+ )
- group_id_to_group[group_id] = group
- group_id_to_json[group_id] = json_object
+ verify_requests.append(verify_request)
@defer.inlineCallbacks
- def handle_key_deferred(group, deferred):
- server_name = group.server_name
+ def handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
try:
- _, _, key_id, verify_key = yield deferred
+ _, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@@ -128,7 +135,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
- json_object = group_id_to_json[group.group_id]
+ json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
@@ -157,36 +164,34 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
- server_to_gids = {}
+ server_to_request_ids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
+ def remove_deferreds(res, server_name, verify_request):
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+ deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- preserve_context_over_fn(
- handle_key_deferred,
- group_id_to_group[g_id],
- deferreds[g_id],
- )
- for g_id in group_ids
+ preserve_context_over_fn(handle_key_deferred, verify_request)
+ for verify_request in verify_requests
]
@defer.inlineCallbacks
@@ -220,7 +225,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
- def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+ def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -237,63 +242,64 @@ class Keyring(object):
merged_results = {}
missing_keys = {}
- for group in group_id_to_group.values():
- missing_keys.setdefault(group.server_name, set()).update(
- group.key_ids
+ for verify_request in verify_requests:
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
- # We now need to figure out which groups we have keys for
- # and which we don't
- missing_groups = {}
- for group in group_id_to_group.values():
- for key_id in group.key_ids:
- if key_id in merged_results[group.server_name]:
+ # We now need to figure out which verify requests we have keys
+ # for and which we don't
+ missing_keys = {}
+ requests_missing_keys = []
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ result_keys = merged_results[server_name]
+
+ if verify_request.deferred.called:
+ # We've already called this deferred, which probably
+ # means that we've already found a key for it.
+ continue
+
+ for key_id in verify_request.key_ids:
+ if key_id in result_keys:
with PreserveLoggingContext():
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
+ verify_request.deferred.callback((
+ server_name,
key_id,
- merged_results[group.server_name][key_id],
+ result_keys[key_id],
))
break
else:
- missing_groups.setdefault(
- group.server_name, []
- ).append(group)
-
- if not missing_groups:
+ # The else block is only reached if the loop above
+ # doesn't break.
+ missing_keys.setdefault(server_name, set()).update(
+ verify_request.key_ids
+ )
+ requests_missing_keys.append(verify_request)
+
+ if not missing_keys:
break
- missing_keys = {
- server_name: set(
- key_id for group in groups for key_id in group.key_ids
- )
- for server_name, groups in missing_groups.items()
- }
-
- for groups in missing_groups.values():
- for group in groups:
- group_id_to_deferred[group.group_id].errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- group.server_name, group.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for deferred in group_id_to_deferred.values():
- if not deferred.called:
- deferred.errback(err)
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
do_iterations().addErrback(on_err)
- return group_id_to_deferred
-
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
|