summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py147
1 files changed, 93 insertions, 54 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e0e9f5d0be..1076ebc036 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,6 +21,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -35,9 +36,6 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     CodeMessageException,
@@ -56,10 +54,9 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -360,10 +357,9 @@ class FederationClient(FederationBase):
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
-        pdus: List[EventBase],
+        pdus: Collection[EventBase],
         room_version: RoomVersion,
         outlier: bool = False,
-        include_none: bool = False,
     ) -> List[EventBase]:
         """Takes a list of PDUs and checks the signatures and hashes of each
         one. If a PDU fails its signature check then we check if we have it in
@@ -380,57 +376,87 @@ class FederationClient(FederationBase):
             pdu
             room_version
             outlier: Whether the events are outliers or not
-            include_none: Whether to include None in the returned list
-                for events that have failed their checks
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
         """
-        deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
-        async def handle_check_result(pdu: EventBase, deferred: Deferred):
-            try:
-                res = await make_deferred_yieldable(deferred)
-            except SynapseError:
-                res = None
+        # We limit how many PDUs we check at once, as if we try to do hundreds
+        # of thousands of PDUs at once we see large memory spikes.
 
-            if not res:
-                # Check local db.
-                res = await self.store.get_event(
-                    pdu.event_id, allow_rejected=True, allow_none=True
-                )
+        valid_pdus = []
 
-            pdu_origin = get_domain_from_id(pdu.sender)
-            if not res and pdu_origin != origin:
-                try:
-                    res = await self.get_pdu(
-                        destinations=[pdu_origin],
-                        event_id=pdu.event_id,
-                        room_version=room_version,
-                        outlier=outlier,
-                        timeout=10000,
-                    )
-                except SynapseError:
-                    pass
+        async def _execute(pdu: EventBase) -> None:
+            valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                pdu=pdu,
+                origin=origin,
+                outlier=outlier,
+                room_version=room_version,
+            )
 
-            if not res:
-                logger.warning(
-                    "Failed to find copy of %s with valid signature", pdu.event_id
-                )
+            if valid_pdu:
+                valid_pdus.append(valid_pdu)
 
-            return res
+        await concurrently_execute(_execute, pdus, 10000)
 
-        handle = preserve_fn(handle_check_result)
-        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+        return valid_pdus
 
-        valid_pdus = await make_deferred_yieldable(
-            defer.gatherResults(deferreds2, consumeErrors=True)
-        ).addErrback(unwrapFirstError)
+    async def _check_sigs_and_hash_and_fetch_one(
+        self,
+        pdu: EventBase,
+        origin: str,
+        room_version: RoomVersion,
+        outlier: bool = False,
+    ) -> Optional[EventBase]:
+        """Takes a PDU and checks its signatures and hashes. If the PDU fails
+        its signature check then we check if we have it in the database and if
+        not then request if from the originating server of that PDU.
 
-        if include_none:
-            return valid_pdus
-        else:
-            return [p for p in valid_pdus if p]
+        If then PDU fails its content hash check then it is redacted.
+
+        Args:
+            origin
+            pdu
+            room_version
+            outlier: Whether the events are outliers or not
+            include_none: Whether to include None in the returned list
+                for events that have failed their checks
+
+        Returns:
+            The PDU (possibly redacted) if it has valid signatures and hashes.
+        """
+
+        res = None
+        try:
+            res = await self._check_sigs_and_hash(room_version, pdu)
+        except SynapseError:
+            pass
+
+        if not res:
+            # Check local db.
+            res = await self.store.get_event(
+                pdu.event_id, allow_rejected=True, allow_none=True
+            )
+
+        pdu_origin = get_domain_from_id(pdu.sender)
+        if not res and pdu_origin != origin:
+            try:
+                res = await self.get_pdu(
+                    destinations=[pdu_origin],
+                    event_id=pdu.event_id,
+                    room_version=room_version,
+                    outlier=outlier,
+                    timeout=10000,
+                )
+            except SynapseError:
+                pass
+
+        if not res:
+            logger.warning(
+                "Failed to find copy of %s with valid signature", pdu.event_id
+            )
+
+        return res
 
     async def get_event_auth(
         self, destination: str, room_id: str, event_id: str
@@ -671,8 +697,6 @@ class FederationClient(FederationBase):
             state = response.state
             auth_chain = response.auth_events
 
-            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
-
             create_event = None
             for e in state:
                 if (e.type, e.state_key) == (EventTypes.Create, ""):
@@ -696,14 +720,29 @@ class FederationClient(FederationBase):
                     % (create_room_version,)
                 )
 
-            valid_pdus = await self._check_sigs_and_hash_and_fetch(
-                destination,
-                list(pdus.values()),
-                outlier=True,
-                room_version=room_version,
+            logger.info(
+                "Processing from send_join %d events", len(state) + len(auth_chain)
             )
 
-            valid_pdus_map = {p.event_id: p for p in valid_pdus}
+            # We now go and check the signatures and hashes for the event. Note
+            # that we limit how many events we process at a time to keep the
+            # memory overhead from exploding.
+            valid_pdus_map: Dict[str, EventBase] = {}
+
+            async def _execute(pdu: EventBase) -> None:
+                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                    pdu=pdu,
+                    origin=destination,
+                    outlier=True,
+                    room_version=room_version,
+                )
+
+                if valid_pdu:
+                    valid_pdus_map[valid_pdu.event_id] = valid_pdu
+
+            await concurrently_execute(
+                _execute, itertools.chain(state, auth_chain), 10000
+            )
 
             # NB: We *need* to copy to ensure that we don't have multiple
             # references being passed on, as that causes... issues.