diff options
author | Erik Johnston <erik@matrix.org> | 2015-06-26 09:52:24 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2015-06-26 10:39:34 +0100 |
commit | b5f55a1d85386834b18828b196e1859a4410d98a (patch) | |
tree | 6c8b39f1a6df244b7cabcc503bc300f0dd686754 /synapse/federation | |
parent | Batch SELECTs in _get_auth_chain_ids_txn (diff) | |
download | synapse-b5f55a1d85386834b18828b196e1859a4410d98a.tar.xz |
Implement bulk verify_signed_json API
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/federation_base.py | 125 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 57 |
2 files changed, 118 insertions, 64 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 299493af91..bdfa247604 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,7 +32,8 @@ logger = logging.getLogger(__name__) class FederationBase(object): @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): + def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, + include_none=False): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -50,84 +51,108 @@ class FederationBase(object): Returns: Deferred : A list of PDUs that have valid signatures and hashes. """ + deferreds = self._check_sigs_and_hashes(pdus) - signed_pdus = [] + def callback(pdu): + return pdu - @defer.inlineCallbacks - def do(pdu): - try: - new_pdu = yield self._check_sigs_and_hash(pdu) - signed_pdus.append(new_pdu) - except SynapseError: - # FIXME: We should handle signature failures more gracefully. + def errback(failure, pdu): + failure.trap(SynapseError) + return None + def try_local_db(res, pdu): + if not res: # Check local db. - new_pdu = yield self.store.get_event( + return self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True, ) - if new_pdu: - signed_pdus.append(new_pdu) - return - - # Check pdu.origin - if pdu.origin != origin: - try: - new_pdu = yield self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - outlier=outlier, - timeout=10000, - ) - - if new_pdu: - signed_pdus.append(new_pdu) - return - except: - pass - + return res + + def try_remote(res, pdu): + if not res and pdu.origin != origin: + return self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + outlier=outlier, + timeout=10000, + ).addErrback(lambda e: None) + return res + + def warn(res, pdu): + if not res: logger.warn( "Failed to find copy of %s with valid signature", pdu.event_id, ) + return res + + for pdu, deferred in zip(pdus, deferreds): + deferred.addCallbacks( + callback, errback, errbackArgs=[pdu] + ).addCallback( + try_local_db, pdu + ).addCallback( + try_remote, pdu + ).addCallback( + warn, pdu + ) - yield defer.gatherResults( - [do(pdu) for pdu in pdus], + valid_pdus = yield defer.gatherResults( + deferreds, consumeErrors=True ).addErrback(unwrapFirstError) - defer.returnValue(signed_pdus) + if include_none: + defer.returnValue(valid_pdus) + else: + defer.returnValue([p for p in valid_pdus if p]) - @defer.inlineCallbacks def _check_sigs_and_hash(self, pdu): - """Throws a SynapseError if the PDU does not have the correct + return self._check_sigs_and_hashes([pdu])[0] + + def _check_sigs_and_hashes(self, pdus): + """Throws a SynapseError if a PDU does not have the correct signatures. Returns: FrozenEvent: Either the given event or it redacted if it failed the content hash check. """ - # Check signatures are correct. - redacted_event = prune_event(pdu) - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - pdu.origin, redacted_pdu_json - ) - except SynapseError: + redacted_pdus = [ + prune_event(pdu) + for pdu in pdus + ] + + deferreds = self.keyring.verify_json_objects_for_server([ + (p.origin, p.get_pdu_json()) + for p in redacted_pdus + ]) + + def callback(_, pdu, redacted): + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + return pdu + + def errback(failure, pdu): + failure.trap(SynapseError) logger.warn( "Signature check failed for %s", pdu.event_id, ) - raise + return failure - if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting.", - pdu.event_id, + for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): + deferred.addCallbacks( + callback, errback, + callbackArgs=[pdu, redacted], + errbackArgs=[pdu], ) - defer.returnValue(redacted_event) - defer.returnValue(pdu) + return deferreds diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d3b46b24c1..7736d14fb5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -30,6 +30,7 @@ import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import copy import itertools import logging import random @@ -167,7 +168,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. pdus[:] = yield defer.gatherResults( - [self._check_sigs_and_hash(pdu) for pdu in pdus], + self._check_sigs_and_hashes(pdus), consumeErrors=True, ).addErrback(unwrapFirstError) @@ -230,7 +231,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(pdu) + pdu = yield self._check_sigs_and_hashes([pdu])[0] break @@ -327,6 +328,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def make_join(self, destinations, room_id, user_id): for destination in destinations: + if destination == self.server_name: + continue + try: ret = yield self.transport_layer.make_join( destination, room_id, user_id @@ -353,6 +357,9 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_join(self, destinations, pdu): for destination in destinations: + if destination == self.server_name: + continue + try: time_now = self._clock.time_msec() _, content = yield self.transport_layer.send_join( @@ -374,17 +381,39 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - signed_state, signed_auth = yield defer.gatherResults( - [ - self._check_sigs_and_hash_and_fetch( - destination, state, outlier=True - ), - self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True - ) - ], - consumeErrors=True - ).addErrback(unwrapFirstError) + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } + + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, pdus.values(), + outlier=True, + ) + + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] + + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] + + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) auth_chain.sort(key=lambda e: e.depth) @@ -396,7 +425,7 @@ class FederationClient(FederationBase): except CodeMessageException: raise except Exception as e: - logger.warn( + logger.exception( "Failed to send_join via %s: %s", destination, e.message ) |