author     Erik Johnston <erik@matrix.org>  2015-06-24 11:21:35 +0100
committer  Erik Johnston <erik@matrix.org>  2015-06-24 11:21:35 +0100
commit     a29319fefa620b4878f9780645515e93dc83472a (patch)
tree       d4afe4ec6d2c01c7d237de854af4681cfd68c8a8
parent     Merge branch 'develop' of github.com:matrix-org/synapse into erikj/persist_ev... (diff)
download   synapse-a29319fefa620b4878f9780645515e93dc83472a.tar.xz
Implement a batch API for verify_json_objects_for_server
-rw-r--r--  synapse/crypto/keyring.py                 | 423
-rw-r--r--  synapse/federation/federation_base.py     |  49
-rw-r--r--  synapse/federation/federation_client.py   |   9
-rw-r--r--  synapse/storage/keys.py                   |   6
4 files changed, 312 insertions, 175 deletions
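For context, the new batch entry point shown in the keyring.py hunks below takes a list of (server_name, json_object) pairs and returns one Deferred per pair. A minimal usage sketch follows; the verify_batch helper, its keyring argument and the gatherResults wrapper are illustrative assumptions, not part of this commit:

    # Illustrative sketch only: driving the batch API added by this commit.
    # "keyring" and "server_and_json" are assumed inputs.
    from twisted.internet import defer

    @defer.inlineCallbacks
    def verify_batch(keyring, server_and_json):
        # One Deferred per (server_name, json_object) pair, in input order.
        # Each fires when that object's signature has been verified, and
        # errbacks with a SynapseError if key fetching or verification fails.
        deferreds = keyring.verify_json_objects_for_server(server_and_json)
        yield defer.gatherResults(deferreds, consumeErrors=True)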
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py

index aff69c5f83..eb94cd5b75 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
 from synapse.api.errors import SynapseError, Codes

 from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError

 from OpenSSL import crypto
+from collections import namedtuple

 import urllib
 import hashlib
 import logging
@@ -38,6 +38,9 @@ import logging
 logger = logging.getLogger(__name__)


+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -49,141 +52,257 @@ class Keyring(object):
         self.key_downloads = {}

-    @defer.inlineCallbacks
     def verify_json_for_server(self, server_name, json_object):
-        logger.debug("Verifying for %s", server_name)
-        key_ids = signature_ids(json_object, server_name)
-        if not key_ids:
-            raise SynapseError(
-                400,
-                "Not signed with a supported algorithm",
-                Codes.UNAUTHORIZED,
-            )
-        try:
-            verify_key = yield self.get_server_verify_key(server_name, key_ids)
-        except IOError as e:
-            logger.warn(
-                "Got IOError when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                502,
-                "Error downloading keys for %s" % (server_name,),
-                Codes.UNAUTHORIZED,
-            )
-        except Exception as e:
-            logger.warn(
-                "Got Exception when downloading keys for %s: %s %s",
-                server_name, type(e).__name__, str(e.message),
-            )
-            raise SynapseError(
-                401,
-                "No key for %s with id %s" % (server_name, key_ids),
-                Codes.UNAUTHORIZED,
-            )
+        return self.verify_json_objects_for_server(
+            [(server_name, json_object)]
+        )[0]
+
+    def verify_json_objects_for_server(self, server_and_json):
+        server_to_key_groupings = {}
+        group_id_to_json = {}
+        group_id_to_group = {}
+        group_ids = []
+
+        next_group_id = 0
+
+        for server_name, json_object in server_and_json:
+            logger.debug("Verifying for %s", server_name)
+            key_ids = signature_ids(json_object, server_name)
+            if not key_ids:
+                raise SynapseError(
+                    400,
+                    "Not signed with a supported algorithm",
+                    Codes.UNAUTHORIZED,
+                )

-        try:
-            verify_signed_json(json_object, server_name, verify_key)
-        except:
-            raise SynapseError(
-                401,
-                "Invalid signature for server %s with key %s:%s" % (
-                    server_name, verify_key.alg, verify_key.version
-                ),
-                Codes.UNAUTHORIZED,
-            )
+            group_id = next_group_id
+            next_group_id += 1
+            group_ids.append(group_id)

-    @defer.inlineCallbacks
-    def get_server_verify_key(self, server_name, key_ids):
-        """Finds a verification key for the server with one of the key ids.
-        Trys to fetch the key from a trusted perspective server first.
-        Args:
-            server_name(str): The name of the server to fetch a key for.
-            keys_ids (list of str): The key_ids to check for.
-        """
-        cached = yield self.store.get_server_verify_keys(server_name, key_ids)
+            group = KeyGroup(server_name, group_id, key_ids)
+
+            group_id_to_group[group_id] = group
+            group_id_to_json[group_id] = json_object
+            server_to_key_groupings.setdefault(server_name, []).append(group)
+
+        @defer.inlineCallbacks
+        def handle_key_deferred(group, deferred):
+            server_name = group.server_name
+            try:
+                _, _, key_id, verify_key = yield deferred
+            except IOError as e:
+                logger.warn(
+                    "Got IOError when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    502,
+                    "Error downloading keys for %s" % (server_name,),
+                    Codes.UNAUTHORIZED,
+                )
+            except Exception as e:
+                logger.exception(
+                    "Got Exception when downloading keys for %s: %s %s",
+                    server_name, type(e).__name__, str(e.message),
+                )
+                raise SynapseError(
+                    401,
+                    "No key for %s with id %s" % (server_name, key_ids),
+                    Codes.UNAUTHORIZED,
+                )
+
+            json_object = group_id_to_json[group.group_id]
+
+            try:
+                verify_signed_json(json_object, server_name, verify_key)
+            except:
+                raise SynapseError(
+                    401,
+                    "Invalid signature for server %s with key %s:%s" % (
+                        server_name, verify_key.alg, verify_key.version
+                    ),
+                    Codes.UNAUTHORIZED,
+                )

-        if cached:
-            defer.returnValue(cached[0])
-            return
+        deferreds = self.get_server_verify_keys(
+            group_id_to_group
+        )

-        download = self.key_downloads.get(server_name)
+        logger.info(
+            "Deferred count: %d vs. %d",
+            len(deferreds.items()),
+            len(server_and_json)
+        )

-        if download is None:
-            download = self._get_server_verify_key_impl(server_name, key_ids)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
+        return [
+            handle_key_deferred(
+                group_id_to_group[g_id],
+                deferreds[g_id],
             )
-            self.key_downloads[server_name] = download
+            for g_id in group_ids
+        ]
+
+    def get_server_verify_keys(self, group_id_to_group):
+        merged_results = {}

-            @download.addBoth
-            def callback(ret):
-                del self.key_downloads[server_name]
-                return ret
+        fns = (
+            self.get_keys_from_store,  # First try the local store
+            self.get_keys_from_perspectives,  # Then try via perspectives
+            self.get_keys_from_server,  # Then try directly
+        )

-        r = yield download.observe()
-        defer.returnValue(r)
+        group_deferreds = {
+            group_id: defer.Deferred()
+            for group_id in group_id_to_group
+        }
+
+        @defer.inlineCallbacks
+        def do_iterations():
+            missing_keys = {
+                group.server_name: key_id
+                for group in group_id_to_group.values()
+                for key_id in group.key_ids
+            }
+
+            for fn in fns:
+                results = yield fn(missing_keys.items())
+                merged_results.update(results)
+
+                missing_groups = {}
+                for group in group_id_to_group.values():
+                    for key_id in group.key_ids:
+                        if key_id in merged_results[group.server_name]:
+                            group_deferreds.pop(group.group_id).callback((
+                                group.group_id,
+                                group.server_name,
+                                key_id,
+                                merged_results[group.server_name][key_id],
+                            ))
+                            break
+                    else:
+                        missing_groups.setdefault(
+                            group.server_name, []
+                        ).append(group)
+
+                if not missing_groups:
+                    break
+
+                missing_keys = {
+                    server_name: set(
+                        key_id for group in groups for key_id in group.key_ids
+                    )
+                    for server_name, groups in missing_groups.items()
+                }
+
+            for group in missing_groups.values():
+                group_deferreds.pop(group.group_id).errback(SynapseError(
+                    401,
+                    "No key for %s with id %s" % (
+                        group.server_name, group.key_ids,
+                    ),
+                    Codes.UNAUTHORIZED,
+                ))
+
+        def on_err(err):
+            for deferred in group_deferreds.values():
+                deferred.errback(err)
+            group_deferreds.clear()
+
+        do_iterations().addErrback(on_err)
+
+        return group_deferreds
     @defer.inlineCallbacks
-    def _get_server_verify_key_impl(self, server_name, key_ids):
-        keys = None
+    def get_keys_from_store(self, server_name_and_key_ids):
+        res = yield defer.gatherResults(
+            [
+                self.store.get_server_verify_keys(server_name, key_ids)
+                for server_name, key_ids in server_name_and_key_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        defer.returnValue(dict(zip(
+            [server_name for server_name, _ in server_name_and_key_ids],
+            res
+        )))

+    @defer.inlineCallbacks
+    def get_keys_from_perspectives(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name, key_ids, perspective_name, perspective_keys
+                    server_name_and_key_ids, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except Exception as e:
-                logging.info(
-                    "Unable to getting key %r for %r from %r: %s %s",
-                    key_ids, server_name, perspective_name,
+                logger.info(
+                    "Unable to get key from %r: %s %s",
+                    perspective_name,
                     type(e).__name__, str(e.message),
                 )

-        perspective_results = yield defer.gatherResults([
+        results = yield defer.gatherResults([
             get_key(p_name, p_keys)
             for p_name, p_keys in self.perspective_servers.items()
         ])

-        for results in perspective_results:
-            if results is not None:
-                keys = results
+        union_of_keys = {}
+        for result in results:
+            for server_name, keys in results.items():
+                union_of_keys.setdefault(server_name, {}).update(keys)

-        limiter = yield get_retry_limiter(
-            server_name,
-            self.clock,
-            self.store,
-        )
+        defer.returnValue(union_of_keys)

-        with limiter:
-            if not keys:
+    @defer.inlineCallbacks
+    def get_keys_from_server(self, server_name_and_key_ids):
+        @defer.inlineCallbacks
+        def get_key(server_name, key_ids):
+            limiter = yield get_retry_limiter(
+                server_name,
+                self.clock,
+                self.store,
+            )
+            with limiter:
+                keys = None
                 try:
                     keys = yield self.get_server_verify_key_v2_direct(
                         server_name, key_ids
                     )
                 except Exception as e:
-                    logging.info(
+                    logger.info(
                         "Unable to getting key %r for %r directly: %s %s",
                         key_ids, server_name,
                         type(e).__name__, str(e.message),
                     )

-            if not keys:
-                keys = yield self.get_server_verify_key_v1_direct(
-                    server_name, key_ids
-                )
+                if not keys:
+                    keys = yield self.get_server_verify_key_v1_direct(
+                        server_name, key_ids
+                    )
+
+                keys = {server_name: keys}
+
+            defer.returnValue(keys)
+
+        results = yield defer.gatherResults([
+            get_key(server_name, key_ids)
+            for server_name, key_ids in server_name_and_key_ids
+        ])
+
+        merged = {}
+        for result in results:
+            merged.update(result)

-        for key_id in key_ids:
-            if key_id in keys:
-                defer.returnValue(keys[key_id])
-                return
-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue({
+            server_name: keys
+            for server_name, keys in merged.items()
+            if keys
+        })

     @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                           perspective_name, perspective_keys):
         limiter = yield get_retry_limiter(
@@ -204,6 +323,7 @@ class Keyring(object):
                         u"minimum_valid_until_ts": 0
                     } for key_id in key_ids
                 }
+                for server_name, key_ids in server_names_and_key_ids
             }
         },
     )
@@ -243,23 +363,24 @@ class Keyring(object):
                 " server %r" % (perspective_name,)
             )

-        response_keys = yield self.process_v2_response(
-            server_name, perspective_name, response
+        processed_response = yield self.process_v2_response(
+            perspective_name, response
         )

-        keys.update(response_keys)
+        for server_name, response_keys in processed_response:
+            keys.setdefault(server_name, {}).update(response_keys)

-        yield self.store_keys(
-            server_name=server_name,
-            from_server=perspective_name,
-            verify_keys=keys,
-        )
+        for server_name, response_keys in keys.items():
+            yield self.store_keys(
+                server_name=server_name,
+                from_server=perspective_name,
+                verify_keys=keys,
+            )

         defer.returnValue(keys)

     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
         keys = {}

         for requested_key_id in key_ids:
@@ -295,25 +416,25 @@ class Keyring(object):
             raise ValueError("TLS certificate not allowed by fingerprints")

             response_keys = yield self.process_v2_response(
-                server_name=server_name,
                 from_server=server_name,
-                requested_id=requested_key_id,
+                requested_ids=[requested_key_id],
                 response_json=response,
             )

             keys.update(response_keys)

-        yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=keys,
-        )
+        for server_name, verify_keys in keys.items():
+            yield self.store_keys(
+                server_name=server_name,
+                from_server=server_name,
+                verify_keys=verify_keys,
+            )

         defer.returnValue(keys)

     @defer.inlineCallbacks
-    def process_v2_response(self, server_name, from_server, response_json,
-                            requested_id=None):
+    def process_v2_response(self, from_server, response_json,
+                            requested_ids=[]):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -335,50 +456,50 @@ class Keyring(object):
             verify_key.time_added = time_now_ms
             old_verify_keys[key_id] = verify_key

-        for key_id in response_json["signatures"].get(server_name, {}):
-            if key_id not in response_json["verify_keys"]:
-                raise ValueError(
-                    "Key response must include verification keys for all"
-                    " signatures"
-                )
-            if key_id in verify_keys:
-                verify_signed_json(
-                    response_json,
-                    server_name,
-                    verify_keys[key_id]
-                )
+        results = {}
+        for server_name, keys_dict in response_json["signatures"].items():
+            for key_id in keys_dict:
+                if key_id not in response_json["verify_keys"]:
+                    raise ValueError(
+                        "Key response must include verification keys for all"
+                        " signatures"
+                    )
+                if key_id in verify_keys:
+                    verify_signed_json(
+                        response_json,
+                        server_name,
+                        verify_keys[key_id]
+                    )

-        signed_key_json = sign_json(
-            response_json,
-            self.config.server_name,
-            self.config.signing_key[0],
-        )
+            signed_key_json = sign_json(
+                response_json,
+                self.config.server_name,
+                self.config.signing_key[0],
+            )

-        signed_key_json_bytes = encode_canonical_json(signed_key_json)
-        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
+            signed_key_json_bytes = encode_canonical_json(signed_key_json)
+            ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

-        updated_key_ids = set()
-        if requested_id is not None:
-            updated_key_ids.add(requested_id)
-        updated_key_ids.update(verify_keys)
-        updated_key_ids.update(old_verify_keys)
+            updated_key_ids = set(requested_ids)
+            updated_key_ids.update(verify_keys)
+            updated_key_ids.update(old_verify_keys)

-        response_keys.update(verify_keys)
-        response_keys.update(old_verify_keys)
+            response_keys.update(verify_keys)
+            response_keys.update(old_verify_keys)

-        for key_id in updated_key_ids:
-            yield self.store.store_server_keys_json(
-                server_name=server_name,
-                key_id=key_id,
-                from_server=server_name,
-                ts_now_ms=time_now_ms,
-                ts_expires_ms=ts_valid_until_ms,
-                key_json_bytes=signed_key_json_bytes,
-            )
+            for key_id in updated_key_ids:
+                yield self.store.store_server_keys_json(
+                    server_name=server_name,
+                    key_id=key_id,
+                    from_server=server_name,
+                    ts_now_ms=time_now_ms,
+                    ts_expires_ms=ts_valid_until_ms,
+                    key_json_bytes=signed_key_json_bytes,
+                )

-        defer.returnValue(response_keys)
+            results[server_name] = response_keys

-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue(results)

     @defer.inlineCallbacks
     def get_server_verify_key_v1_direct(self, server_name, key_ids):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..407e0f815c 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -99,35 +99,50 @@ class FederationBase(object):
         defer.returnValue(signed_pdus)

-    @defer.inlineCallbacks
     def _check_sigs_and_hash(self, pdu):
-        """Throws a SynapseError if the PDU does not have the correct
+        return self._check_sigs_and_hashes([pdu])[0]
+
+    def _check_sigs_and_hashes(self, pdus):
+        """Throws a SynapseError if a PDU does not have the correct
         signatures.

         Returns:
             FrozenEvent: Either the given event or it redacted if it failed
             the content hash check.
         """
-        # Check signatures are correct.
-        redacted_event = prune_event(pdu)
-        redacted_pdu_json = redacted_event.get_pdu_json()

-        try:
-            yield self.keyring.verify_json_for_server(
-                pdu.origin, redacted_pdu_json
-            )
-        except SynapseError:
+        redacted_pdus = [
+            prune_event(pdu)
+            for pdu in pdus
+        ]
+
+        deferreds = self.keyring.verify_json_objects_for_server([
+            (p.origin, p.get_pdu_json())
+            for p in redacted_pdus
+        ])
+
+        def callback(_, pdu, redacted):
+            if not check_event_content_hash(pdu):
+                logger.warn(
+                    "Event content has been tampered, redacting %s: %s",
+                    pdu.event_id, pdu.get_pdu_json()
+                )
+                return redacted
+            return pdu
+
+        def errback(failure, pdu):
+            failure.trap(SynapseError)
             logger.warn(
                 "Signature check failed for %s",
                 pdu.event_id,
             )
-            raise
+            return failure

-        if not check_event_content_hash(pdu):
-            logger.warn(
-                "Event content has been tampered, redacting.",
-                pdu.event_id,
+        for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+            deferred.addCallbacks(
+                callback, errback,
+                callbackArgs=[pdu, redacted],
+                errbackArgs=[pdu],
             )
-            defer.returnValue(redacted_event)

-        defer.returnValue(pdu)
+        return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7ee3c66bf2..47d71542e4 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -166,10 +166,7 @@ class FederationClient(FederationBase):
         ]

         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield defer.gatherResults(
-            [self._check_sigs_and_hash(pdu) for pdu in pdus],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        pdus[:] = yield self._check_sigs_and_hashes(pdus)

         defer.returnValue(pdus)

@@ -230,7 +227,7 @@ class FederationClient(FederationBase):
                 pdu = pdu_list[0]

                 # Check signatures are correct.
-                pdu = yield self._check_sigs_and_hash(pdu)
+                pdu = yield self._check_sigs_and_hashes([pdu])[0]

                 break

@@ -402,7 +399,7 @@ class FederationClient(FederationBase):
             except CodeMessageException:
                 raise
             except Exception as e:
-                logger.warn(
+                logger.exception(
                     "Failed to send_join via %s: %s",
                     destination, e.message
                 )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 2902e35181..4f990b7792 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,7 +101,11 @@ class KeyStore(SQLBaseStore):
             (list of VerifyKey): The verification keys.
         """
         keys = yield self.get_all_server_verify_keys(server_name)
-        defer.returnValue([keys[k] for k in key_ids if k in keys])
+        defer.returnValue({
+            k: keys[k]
+            for k in key_ids
+            if k in keys and keys[k]
+        })

     @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
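The storage hunk above changes KeyStore.get_server_verify_keys to return a dict of key_id to VerifyKey (only requested ids that are present and non-empty) instead of a list, which is the shape the new Keyring.get_keys_from_store zips up per server. A small illustrative sketch of the expected result shape; the server name and key ids are made-up examples, not from this commit:

    # Illustrative sketch only: result shape after this change.
    from twisted.internet import defer

    @defer.inlineCallbacks
    def example(store):
        keys = yield store.get_server_verify_keys(
            "example.org", ["ed25519:a_key", "ed25519:old_key"]
        )
        # keys is now a dict such as {"ed25519:a_key": <VerifyKey>},
        # containing only the requested ids that exist and are non-empty.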