diff options
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r-- | synapse/federation/federation_client.py | 147 |
1 files changed, 93 insertions, 54 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index e0e9f5d0be..1076ebc036 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,6 +21,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -35,9 +36,6 @@ from typing import ( import attr from prometheus_client import Counter -from twisted.internet import defer -from twisted.internet.defer import Deferred - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( CodeMessageException, @@ -56,10 +54,9 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -360,10 +357,9 @@ class FederationClient(FederationBase): async def _check_sigs_and_hash_and_fetch( self, origin: str, - pdus: List[EventBase], + pdus: Collection[EventBase], room_version: RoomVersion, outlier: bool = False, - include_none: bool = False, ) -> List[EventBase]: """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in @@ -380,57 +376,87 @@ class FederationClient(FederationBase): pdu room_version outlier: Whether the events are outliers or not - include_none: Whether to include None in the returned list - for events that have failed their checks Returns: A list of PDUs that have valid signatures and hashes. 
""" - deferreds = self._check_sigs_and_hashes(room_version, pdus) - async def handle_check_result(pdu: EventBase, deferred: Deferred): - try: - res = await make_deferred_yieldable(deferred) - except SynapseError: - res = None + # We limit how many PDUs we check at once, as if we try to do hundreds + # of thousands of PDUs at once we see large memory spikes. - if not res: - # Check local db. - res = await self.store.get_event( - pdu.event_id, allow_rejected=True, allow_none=True - ) + valid_pdus = [] - pdu_origin = get_domain_from_id(pdu.sender) - if not res and pdu_origin != origin: - try: - res = await self.get_pdu( - destinations=[pdu_origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) - except SynapseError: - pass + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=origin, + outlier=outlier, + room_version=room_version, + ) - if not res: - logger.warning( - "Failed to find copy of %s with valid signature", pdu.event_id - ) + if valid_pdu: + valid_pdus.append(valid_pdu) - return res + await concurrently_execute(_execute, pdus, 10000) - handle = preserve_fn(handle_check_result) - deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] + return valid_pdus - valid_pdus = await make_deferred_yieldable( - defer.gatherResults(deferreds2, consumeErrors=True) - ).addErrback(unwrapFirstError) + async def _check_sigs_and_hash_and_fetch_one( + self, + pdu: EventBase, + origin: str, + room_version: RoomVersion, + outlier: bool = False, + ) -> Optional[EventBase]: + """Takes a PDU and checks its signatures and hashes. If the PDU fails + its signature check then we check if we have it in the database and if + not then request it from the originating server of that PDU. - if include_none: - return valid_pdus - else: - return [p for p in valid_pdus if p] + If the PDU fails its content hash check then it is redacted. 
+ + Args: + origin + pdu + room_version + outlier: Whether the event is an outlier or not + + Returns: + The PDU (possibly redacted) if it has valid signatures and hashes, + or None if no copy with a valid signature could be found (neither + in the local database nor from the event's originating server). + """ + + res = None + try: + res = await self._check_sigs_and_hash(room_version, pdu) + except SynapseError: + pass + + if not res: + # Check local db. + res = await self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) + + pdu_origin = get_domain_from_id(pdu.sender) + if not res and pdu_origin != origin: + try: + res = await self.get_pdu( + destinations=[pdu_origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, + ) + except SynapseError: + pass + + if not res: + logger.warning( + "Failed to find copy of %s with valid signature", pdu.event_id + ) + + return res async def get_event_auth( self, destination: str, room_id: str, event_id: str @@ -671,8 +697,6 @@ class FederationClient(FederationBase): state = response.state auth_chain = response.auth_events - pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} - create_event = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): @@ -696,14 +720,29 @@ class FederationClient(FederationBase): % (create_room_version,) ) - valid_pdus = await self._check_sigs_and_hash_and_fetch( - destination, - list(pdus.values()), - outlier=True, - room_version=room_version, + logger.info( + "Processing from send_join %d events", len(state) + len(auth_chain) ) - valid_pdus_map = {p.event_id: p for p in valid_pdus} + # We now go and check the signatures and hashes for the event. Note + # that we limit how many events we process at a time to keep the + # memory overhead from exploding. 
+ valid_pdus_map: Dict[str, EventBase] = {} + + async def _execute(pdu: EventBase) -> None: + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=pdu, + origin=destination, + outlier=True, + room_version=room_version, + ) + + if valid_pdu: + valid_pdus_map[valid_pdu.event_id] = valid_pdu + + await concurrently_execute( + _execute, itertools.chain(state, auth_chain), 10000 + ) # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. |