diff options
Diffstat (limited to 'synapse/crypto')
-rw-r--r-- | synapse/crypto/keyclient.py | 12 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 154 |
2 files changed, 93 insertions, 73 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 54b83da9d8..c2bd64d6c2 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient): def __init__(self): self.remote_key = defer.Deferred() self.host = None + self._peer = None def connectionMade(self): - self.host = self.transport.getHost() - logger.debug("Connected to %s", self.host) + self._peer = self.transport.getPeer() + logger.debug("Connected to %s", self._peer) + self.sendCommand(b"GET", self.path) if self.host: self.sendHeader(b"Host", self.host) @@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient): self.timer.cancel() def on_timeout(self): - logger.debug("Timeout waiting for response from %s", self.host) + logger.debug( + "Timeout waiting for response from %s: %s", + self.host, self._peer, + ) self.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() @@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory): def protocol(self): protocol = SynapseKeyClientProtocol() protocol.path = self.path + protocol.host = self.host return protocol diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d08ee0aa91..5012c10ee8 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -44,7 +44,21 @@ import logging logger = logging.getLogger(__name__) -KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) +VerifyKeyRequest = namedtuple("VerifyRequest", ( + "server_name", "key_ids", "json_object", "deferred" +)) +""" +A request for a verify key to verify a JSON object. + +Attributes: + server_name(str): The name of the server to verify against. + key_ids(set(str)): The set of key_ids to that could be used to verify the + JSON object + json_object(dict): The JSON object to verify. + deferred(twisted.internet.defer.Deferred): + A deferred (server_name, key_id, verify_key) tuple that resolves when + a verify key has been fetched +""" class Keyring(object): @@ -74,39 +88,32 @@ class Keyring(object): list of deferreds indicating success or failure to verify each json object's signature for the given server_name. """ - group_id_to_json = {} - group_id_to_group = {} - group_ids = [] - - next_group_id = 0 - deferreds = {} + verify_requests = [] for server_name, json_object in server_and_json: logger.debug("Verifying for %s", server_name) - group_id = next_group_id - next_group_id += 1 - group_ids.append(group_id) key_ids = signature_ids(json_object, server_name) if not key_ids: - deferreds[group_id] = defer.fail(SynapseError( + deferred = defer.fail(SynapseError( 400, "Not signed with a supported algorithm", Codes.UNAUTHORIZED, )) else: - deferreds[group_id] = defer.Deferred() + deferred = defer.Deferred() - group = KeyGroup(server_name, group_id, key_ids) + verify_request = VerifyKeyRequest( + server_name, key_ids, json_object, deferred + ) - group_id_to_group[group_id] = group - group_id_to_json[group_id] = json_object + verify_requests.append(verify_request) @defer.inlineCallbacks - def handle_key_deferred(group, deferred): - server_name = group.server_name + def handle_key_deferred(verify_request): + server_name = verify_request.server_name try: - _, _, key_id, verify_key = yield deferred + _, key_id, verify_key = yield verify_request.deferred except IOError as e: logger.warn( "Got IOError when downloading keys for %s: %s %s", @@ -128,7 +135,7 @@ class Keyring(object): Codes.UNAUTHORIZED, ) - json_object = group_id_to_json[group.group_id] + json_object = verify_request.json_object try: verify_signed_json(json_object, server_name, verify_key) @@ -157,36 +164,34 @@ class Keyring(object): # Actually start fetching keys. wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + lambda _: self.get_server_verify_keys(verify_requests) ) # When we've finished fetching all the keys for a given server_name, # resolve the deferred passed to `wait_for_previous_lookups` so that # any lookups waiting will proceed. - server_to_gids = {} + server_to_request_ids = {} - def remove_deferreds(res, server_name, group_id): - server_to_gids[server_name].discard(group_id) - if not server_to_gids[server_name]: + def remove_deferreds(res, server_name, verify_request): + request_id = id(verify_request) + server_to_request_ids[server_name].discard(request_id) + if not server_to_request_ids[server_name]: d = server_to_deferred.pop(server_name, None) if d: d.callback(None) return res - for g_id, deferred in deferreds.items(): - server_name = group_id_to_group[g_id].server_name - server_to_gids.setdefault(server_name, set()).add(g_id) - deferred.addBoth(remove_deferreds, server_name, g_id) + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) + deferred.addBoth(remove_deferreds, server_name, verify_request) # Pass those keys to handle_key_deferred so that the json object # signatures can be verified return [ - preserve_context_over_fn( - handle_key_deferred, - group_id_to_group[g_id], - deferreds[g_id], - ) - for g_id in group_ids + preserve_context_over_fn(handle_key_deferred, verify_request) + for verify_request in verify_requests ] @defer.inlineCallbacks @@ -220,7 +225,7 @@ class Keyring(object): d.addBoth(rm, server_name) - def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): + def get_server_verify_keys(self, verify_requests): """Takes a dict of KeyGroups and tries to find at least one key for each group. """ @@ -237,62 +242,64 @@ class Keyring(object): merged_results = {} missing_keys = {} - for group in group_id_to_group.values(): - missing_keys.setdefault(group.server_name, set()).update( - group.key_ids + for verify_request in verify_requests: + missing_keys.setdefault(verify_request.server_name, set()).update( + verify_request.key_ids ) for fn in key_fetch_fns: results = yield fn(missing_keys.items()) merged_results.update(results) - # We now need to figure out which groups we have keys for - # and which we don't - missing_groups = {} - for group in group_id_to_group.values(): - for key_id in group.key_ids: - if key_id in merged_results[group.server_name]: + # We now need to figure out which verify requests we have keys + # for and which we don't + missing_keys = {} + requests_missing_keys = [] + for verify_request in verify_requests: + server_name = verify_request.server_name + result_keys = merged_results[server_name] + + if verify_request.deferred.called: + # We've already called this deferred, which probably + # means that we've already found a key for it. + continue + + for key_id in verify_request.key_ids: + if key_id in result_keys: with PreserveLoggingContext(): - group_id_to_deferred[group.group_id].callback(( - group.group_id, - group.server_name, + verify_request.deferred.callback(( + server_name, key_id, - merged_results[group.server_name][key_id], + result_keys[key_id], )) break else: - missing_groups.setdefault( - group.server_name, [] - ).append(group) - - if not missing_groups: + # The else block is only reached if the loop above + # doesn't break. + missing_keys.setdefault(server_name, set()).update( + verify_request.key_ids + ) + requests_missing_keys.append(verify_request) + + if not missing_keys: break - missing_keys = { - server_name: set( - key_id for group in groups for key_id in group.key_ids - ) - for server_name, groups in missing_groups.items() - } - - for group in missing_groups.values(): - group_id_to_deferred[group.group_id].errback(SynapseError( + for verify_request in requests_missing_keys.values(): + verify_request.deferred.errback(SynapseError( 401, "No key for %s with id %s" % ( - group.server_name, group.key_ids, + verify_request.server_name, verify_request.key_ids, ), Codes.UNAUTHORIZED, )) def on_err(err): - for deferred in group_id_to_deferred.values(): - if not deferred.called: - deferred.errback(err) + for verify_request in verify_requests: + if not verify_request.deferred.called: + verify_request.deferred.errback(err) do_iterations().addErrback(on_err) - return group_id_to_deferred - @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): res = yield defer.gatherResults( @@ -447,7 +454,7 @@ class Keyring(object): ) processed_response = yield self.process_v2_response( - perspective_name, response + perspective_name, response, only_from_server=False ) for server_name, response_keys in processed_response.items(): @@ -527,7 +534,7 @@ class Keyring(object): @defer.inlineCallbacks def process_v2_response(self, from_server, response_json, - requested_ids=[]): + requested_ids=[], only_from_server=True): time_now_ms = self.clock.time_msec() response_keys = {} verify_keys = {} @@ -551,6 +558,13 @@ class Keyring(object): results = {} server_name = response_json["server_name"] + if only_from_server: + if server_name != from_server: + raise ValueError( + "Expected a response for server %r not %r" % ( + from_server, server_name + ) + ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise ValueError( |