diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d08ee0aa91..7cd11cfae7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -44,7 +44,25 @@ import logging
logger = logging.getLogger(__name__)
-KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+VerifyKeyRequest = namedtuple("VerifyRequest", (
+ "server_name", "key_ids", "json_object", "deferred"
+))
+"""
+A request for a verify key to verify a JSON object.
+
+Attributes:
+ server_name(str): The name of the server to verify against.
+ key_ids(set(str)): The set of key_ids to that could be used to verify the
+ JSON object
+ json_object(dict): The JSON object to verify.
+ deferred(twisted.internet.defer.Deferred):
+ A deferred (server_name, key_id, verify_key) tuple that resolves when
+ a verify key has been fetched
+"""
+
+
+class KeyLookupError(ValueError):
+ pass
class Keyring(object):
@@ -74,39 +92,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
- group_id_to_json = {}
- group_id_to_group = {}
- group_ids = []
-
- next_group_id = 0
- deferreds = {}
+ verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
- group_id = next_group_id
- next_group_id += 1
- group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
- deferreds[group_id] = defer.fail(SynapseError(
+ deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
- deferreds[group_id] = defer.Deferred()
+ deferred = defer.Deferred()
- group = KeyGroup(server_name, group_id, key_ids)
+ verify_request = VerifyKeyRequest(
+ server_name, key_ids, json_object, deferred
+ )
- group_id_to_group[group_id] = group
- group_id_to_json[group_id] = json_object
+ verify_requests.append(verify_request)
@defer.inlineCallbacks
- def handle_key_deferred(group, deferred):
- server_name = group.server_name
+ def handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
try:
- _, _, key_id, verify_key = yield deferred
+ _, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@@ -128,7 +139,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
- json_object = group_id_to_json[group.group_id]
+ json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
@@ -157,36 +168,34 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
- server_to_gids = {}
+ server_to_request_ids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
+ def remove_deferreds(res, server_name, verify_request):
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+ deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- preserve_context_over_fn(
- handle_key_deferred,
- group_id_to_group[g_id],
- deferreds[g_id],
- )
- for g_id in group_ids
+ preserve_context_over_fn(handle_key_deferred, verify_request)
+ for verify_request in verify_requests
]
@defer.inlineCallbacks
@@ -220,7 +229,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
- def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+ def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -237,62 +246,64 @@ class Keyring(object):
merged_results = {}
missing_keys = {}
- for group in group_id_to_group.values():
- missing_keys.setdefault(group.server_name, set()).update(
- group.key_ids
+ for verify_request in verify_requests:
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
- # We now need to figure out which groups we have keys for
- # and which we don't
- missing_groups = {}
- for group in group_id_to_group.values():
- for key_id in group.key_ids:
- if key_id in merged_results[group.server_name]:
+ # We now need to figure out which verify requests we have keys
+ # for and which we don't
+ missing_keys = {}
+ requests_missing_keys = []
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ result_keys = merged_results[server_name]
+
+ if verify_request.deferred.called:
+ # We've already called this deferred, which probably
+ # means that we've already found a key for it.
+ continue
+
+ for key_id in verify_request.key_ids:
+ if key_id in result_keys:
with PreserveLoggingContext():
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
+ verify_request.deferred.callback((
+ server_name,
key_id,
- merged_results[group.server_name][key_id],
+ result_keys[key_id],
))
break
else:
- missing_groups.setdefault(
- group.server_name, []
- ).append(group)
-
- if not missing_groups:
+ # The else block is only reached if the loop above
+ # doesn't break.
+ missing_keys.setdefault(server_name, set()).update(
+ verify_request.key_ids
+ )
+ requests_missing_keys.append(verify_request)
+
+ if not missing_keys:
break
- missing_keys = {
- server_name: set(
- key_id for group in groups for key_id in group.key_ids
- )
- for server_name, groups in missing_groups.items()
- }
-
- for group in missing_groups.values():
- group_id_to_deferred[group.group_id].errback(SynapseError(
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
- group.server_name, group.key_ids,
+ verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
- for deferred in group_id_to_deferred.values():
- if not deferred.called:
- deferred.errback(err)
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
do_iterations().addErrback(on_err)
- return group_id_to_deferred
-
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
@@ -356,7 +367,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
- "Unable to getting key %r for %r directly: %s %s",
+ "Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@@ -418,7 +429,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
- raise ValueError(
+ raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@@ -441,13 +452,13 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
- raise ValueError(
+ raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
processed_response = yield self.process_v2_response(
- perspective_name, response
+ perspective_name, response, only_from_server=False
)
for server_name, response_keys in processed_response.items():
@@ -484,10 +495,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
- raise ValueError("Key response missing TLS fingerprints")
+ raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@@ -501,7 +512,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
- raise ValueError("TLS certificate not allowed by fingerprints")
+ raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@@ -527,7 +538,7 @@ class Keyring(object):
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
- requested_ids=[]):
+ requested_ids=[], only_from_server=True):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -551,9 +562,16 @@ class Keyring(object):
results = {}
server_name = response_json["server_name"]
+ if only_from_server:
+ if server_name != from_server:
+ raise KeyLookupError(
+ "Expected a response for server %r not %r" % (
+ from_server, server_name
+ )
+ )
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -621,15 +639,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
- raise ValueError("Key response missing TLS certificate")
+ raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise ValueError("TLS certificate doesn't match")
+ raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
@@ -645,7 +663,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
|