diff options
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r-- | synapse/federation/federation_client.py | 90 |
1 files changed, 84 insertions, 6 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4870e39652..8c6b839478 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -33,6 +33,7 @@ from typing import ( from prometheus_client import Counter from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( @@ -51,7 +52,7 @@ from synapse.api.room_versions import ( ) from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function from synapse.types import JsonDict from synapse.util import unwrapFirstError @@ -187,7 +188,7 @@ class FederationClient(FederationBase): async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] - ) -> List[EventBase]: + ) -> Optional[List[EventBase]]: """Requests some more historic PDUs for the given room from the given destination server. @@ -199,9 +200,9 @@ class FederationClient(FederationBase): """ logger.debug("backfill extrem=%s", extremities) - # If there are no extremeties then we've (probably) reached the start. + # If there are no extremities then we've (probably) reached the start. if not extremities: - return + return None transaction_data = await self.transport_layer.backfill( dest, room_id, extremities, limit @@ -284,7 +285,7 @@ class FederationClient(FederationBase): pdu_list = [ event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] - ] + ] # type: List[EventBase] if pdu_list and pdu_list[0]: pdu = pdu_list[0] @@ -345,6 +346,83 @@ class FederationClient(FederationBase): return state_event_ids, auth_event_ids + async def _check_sigs_and_hash_and_fetch( + self, + origin: str, + pdus: List[EventBase], + room_version: str, + outlier: bool = False, + include_none: bool = False, + ) -> List[EventBase]: + """Takes a list of PDUs and checks the signatures and hashs of each + one. If a PDU fails its signature check then we check if we have it in + the database and if not then request if from the originating server of + that PDU. + + If a PDU fails its content hash check then it is redacted. + + The given list of PDUs are not modified, instead the function returns + a new list. + + Args: + origin + pdu + room_version + outlier: Whether the events are outliers or not + include_none: Whether to include None in the returned list + for events that have failed their checks + + Returns: + Deferred : A list of PDUs that have valid signatures and hashes. + """ + deferreds = self._check_sigs_and_hashes(room_version, pdus) + + @defer.inlineCallbacks + def handle_check_result(pdu: EventBase, deferred: Deferred): + try: + res = yield make_deferred_yieldable(deferred) + except SynapseError: + res = None + + if not res: + # Check local db. + res = yield self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) + + if not res and pdu.origin != origin: + try: + res = yield defer.ensureDeferred( + self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + room_version=room_version, # type: ignore + outlier=outlier, + timeout=10000, + ) + ) + except SynapseError: + pass + + if not res: + logger.warning( + "Failed to find copy of %s with valid signature", pdu.event_id + ) + + return res + + handle = preserve_fn(handle_check_result) + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] + + valid_pdus = await make_deferred_yieldable( + defer.gatherResults(deferreds2, consumeErrors=True) + ).addErrback(unwrapFirstError) + + if include_none: + return valid_pdus + else: + return [p for p in valid_pdus if p] + async def get_event_auth(self, destination, room_id, event_id): res = await self.transport_layer.get_event_auth(destination, room_id, event_id) @@ -615,7 +693,7 @@ class FederationClient(FederationBase): ] if auth_chain_create_events != [create_event.event_id]: raise InvalidResponseError( - "Unexpected create event(s) in auth chain" + "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) ) |