summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10117.feature1
-rw-r--r--synapse/crypto/keyring.py46
-rw-r--r--synapse/federation/federation_base.py243
-rw-r--r--synapse/federation/federation_client.py147
-rw-r--r--synapse/util/async_helpers.py21
5 files changed, 202 insertions, 256 deletions
diff --git a/changelog.d/10117.feature b/changelog.d/10117.feature
new file mode 100644
index 0000000000..e137e142c6
--- /dev/null
+++ b/changelog.d/10117.feature
@@ -0,0 +1 @@
+Significantly reduce memory usage of joining large remote rooms.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c840ffca71..e5a4685ed4 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -233,41 +233,19 @@ class Keyring:
             for server_name, json_object, validity_time in server_and_json
         ]
 
-    def verify_events_for_server(
-        self, server_and_events: Iterable[Tuple[str, EventBase, int]]
-    ) -> List[defer.Deferred]:
-        """Bulk verification of signatures on events.
-
-        Args:
-            server_and_events:
-                Iterable of `(server_name, event, validity_time)` tuples.
-
-                `server_name` is which server we are verifying the signature for
-                on the event.
-
-                `event` is the event that we'll verify the signatures of for
-                the given `server_name`.
-
-                `validity_time` is a timestamp at which the signing key must be
-                valid.
-
-        Returns:
-            List<Deferred[None]>: for each input triplet, a deferred indicating success
-                or failure to verify each event's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
-        """
-        return [
-            run_in_background(
-                self.process_request,
-                VerifyJsonRequest.from_event(
-                    server_name,
-                    event,
-                    validity_time,
-                ),
+    async def verify_event_for_server(
+        self,
+        server_name: str,
+        event: EventBase,
+        validity_time: int,
+    ) -> None:
+        await self.process_request(
+            VerifyJsonRequest.from_event(
+                server_name,
+                event,
+                validity_time,
             )
-            for server_name, event, validity_time in server_and_events
-        ]
+        )
 
     async def process_request(self, verify_request: VerifyJsonRequest) -> None:
         """Processes the `VerifyJsonRequest`. Raises if the object is not signed
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 3fe496dcd3..c066617b92 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -14,11 +14,6 @@
 # limitations under the License.
 import logging
 from collections import namedtuple
-from typing import Iterable, List
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred, DeferredList
-from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    current_context,
-    make_deferred_yieldable,
-)
 from synapse.types import JsonDict, get_domain_from_id
 
 logger = logging.getLogger(__name__)
@@ -48,112 +38,82 @@ class FederationBase:
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    def _check_sigs_and_hash(
+    async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
-    ) -> Deferred:
-        return make_deferred_yieldable(
-            self._check_sigs_and_hashes(room_version, [pdu])[0]
-        )
-
-    def _check_sigs_and_hashes(
-        self, room_version: RoomVersion, pdus: List[EventBase]
-    ) -> List[Deferred]:
-        """Checks that each of the received events is correctly signed by the
-        sending server.
+    ) -> EventBase:
+        """Checks that event is correctly signed by the sending server.
 
         Args:
-            room_version: The room version of the PDUs
-            pdus: the events to be checked
+            room_version: The room version of the PDU
+            pdu: the event to be checked
 
         Returns:
-            For each input event, a deferred which:
-              * returns the original event if the checks pass
-              * returns a redacted version of the event (if the signature
+              * the original event if the checks pass
+              * a redacted version of the event (if the signature
                 matched but the hash did not)
-              * throws a SynapseError if the signature check failed.
-            The deferreds run their callbacks in the sentinel
-        """
-        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
-
-        ctx = current_context()
-
-        @defer.inlineCallbacks
-        def callback(_, pdu: EventBase):
-            with PreserveLoggingContext(ctx):
-                if not check_event_content_hash(pdu):
-                    # let's try to distinguish between failures because the event was
-                    # redacted (which are somewhat expected) vs actual ball-tampering
-                    # incidents.
-                    #
-                    # This is just a heuristic, so we just assume that if the keys are
-                    # about the same between the redacted and received events, then the
-                    # received event was probably a redacted copy (but we then use our
-                    # *actual* redacted copy to be on the safe side.)
-                    redacted_event = prune_event(pdu)
-                    if set(redacted_event.keys()) == set(pdu.keys()) and set(
-                        redacted_event.content.keys()
-                    ) == set(pdu.content.keys()):
-                        logger.info(
-                            "Event %s seems to have been redacted; using our redacted "
-                            "copy",
-                            pdu.event_id,
-                        )
-                    else:
-                        logger.warning(
-                            "Event %s content has been tampered, redacting",
-                            pdu.event_id,
-                        )
-                    return redacted_event
-
-                result = yield defer.ensureDeferred(
-                    self.spam_checker.check_event_for_spam(pdu)
+              * throws a SynapseError if the signature check failed."""
+        try:
+            await _check_sigs_on_pdu(self.keyring, room_version, pdu)
+        except SynapseError as e:
+            logger.warning(
+                "Signature check failed for %s: %s",
+                pdu.event_id,
+                e,
+            )
+            raise
+
+        if not check_event_content_hash(pdu):
+            # let's try to distinguish between failures because the event was
+            # redacted (which are somewhat expected) vs actual ball-tampering
+            # incidents.
+            #
+            # This is just a heuristic, so we just assume that if the keys are
+            # about the same between the redacted and received events, then the
+            # received event was probably a redacted copy (but we then use our
+            # *actual* redacted copy to be on the safe side.)
+            redacted_event = prune_event(pdu)
+            if set(redacted_event.keys()) == set(pdu.keys()) and set(
+                redacted_event.content.keys()
+            ) == set(pdu.content.keys()):
+                logger.info(
+                    "Event %s seems to have been redacted; using our redacted copy",
+                    pdu.event_id,
                 )
-
-                if result:
-                    logger.warning(
-                        "Event contains spam, redacting %s: %s",
-                        pdu.event_id,
-                        pdu.get_pdu_json(),
-                    )
-                    return prune_event(pdu)
-
-                return pdu
-
-        def errback(failure: Failure, pdu: EventBase):
-            failure.trap(SynapseError)
-            with PreserveLoggingContext(ctx):
+            else:
                 logger.warning(
-                    "Signature check failed for %s: %s",
+                    "Event %s content has been tampered, redacting",
                     pdu.event_id,
-                    failure.getErrorMessage(),
                 )
-            return failure
+            return redacted_event
 
-        for deferred, pdu in zip(deferreds, pdus):
-            deferred.addCallbacks(
-                callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
+        result = await self.spam_checker.check_event_for_spam(pdu)
+
+        if result:
+            logger.warning(
+                "Event contains spam, redacting %s: %s",
+                pdu.event_id,
+                pdu.get_pdu_json(),
             )
+            return prune_event(pdu)
 
-        return deferreds
+        return pdu
 
 
 class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass
 
 
-def _check_sigs_on_pdus(
-    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
-) -> List[Deferred]:
+async def _check_sigs_on_pdu(
+    keyring: Keyring, room_version: RoomVersion, pdu: EventBase
+) -> None:
     """Check that the given events are correctly signed
 
+    Raise a SynapseError if the event wasn't correctly signed.
+
     Args:
         keyring: keyring object to do the checks
         room_version: the room version of the PDUs
         pdus: the events to be checked
-
-    Returns:
-        A Deferred for each event in pdus, which will either succeed if
-           the signatures are valid, or fail (with a SynapseError) if not.
     """
 
     # we want to check that the event is signed by:
@@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
     # let's start by getting the domain for each pdu, and flattening the event back
     # to JSON.
 
-    pdus_to_check = [
-        PduToCheckSig(
-            pdu=p,
-            sender_domain=get_domain_from_id(p.sender),
-            deferreds=[],
-        )
-        for p in pdus
-    ]
-
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
-    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
-
-    more_deferreds = keyring.verify_events_for_server(
-        [
-            (
-                p.sender_domain,
-                p.pdu,
-                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+    if not _is_invite_via_3pid(pdu):
+        try:
+            await keyring.verify_event_for_server(
+                get_domain_from_id(pdu.sender),
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
             )
-            for p in pdus_to_check_sender
-        ]
-    )
-
-    def sender_err(e, pdu_to_check):
-        errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
-            pdu_to_check.pdu.event_id,
-            pdu_to_check.sender_domain,
-            e.getErrorMessage(),
-        )
-        raise SynapseError(403, errmsg, Codes.FORBIDDEN)
-
-    for p, d in zip(pdus_to_check_sender, more_deferreds):
-        d.addErrback(sender_err, p)
-        p.deferreds.append(d)
+        except Exception as e:
+            errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
+                pdu.event_id,
+                get_domain_from_id(pdu.sender),
+                e,
+            )
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
     # now let's look for events where the sender's domain is different to the
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if room_version.event_format == EventFormatVersions.V1:
-        pdus_to_check_event_id = [
-            p
-            for p in pdus_to_check
-            if p.sender_domain != get_domain_from_id(p.pdu.event_id)
-        ]
-
-        more_deferreds = keyring.verify_events_for_server(
-            [
-                (
-                    get_domain_from_id(p.pdu.event_id),
-                    p.pdu,
-                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                )
-                for p in pdus_to_check_event_id
-            ]
-        )
-
-        def event_err(e, pdu_to_check):
+    if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
+        pdu.event_id
+    ) != get_domain_from_id(pdu.sender):
+        try:
+            await keyring.verify_event_for_server(
+                get_domain_from_id(pdu.event_id),
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+            )
+        except Exception as e:
             errmsg = (
-                "event id %s: unable to verify signature for event id domain: %s"
-                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
+                "event id %s: unable to verify signature for event id domain %s: %s"
+                % (
+                    pdu.event_id,
+                    get_domain_from_id(pdu.event_id),
+                    e,
+                )
             )
             raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
-        for p, d in zip(pdus_to_check_event_id, more_deferreds):
-            d.addErrback(event_err, p)
-            p.deferreds.append(d)
-
-    # replace lists of deferreds with single Deferreds
-    return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-
-
-def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
-    """Given a list of deferreds, either return the single deferred,
-    combine into a DeferredList, or return an already resolved deferred.
-    """
-    if len(deferreds) > 1:
-        return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
-    elif len(deferreds) == 1:
-        return deferreds[0]
-    else:
-        return defer.succeed(None)
-
 
 def _is_invite_via_3pid(event: EventBase) -> bool:
     return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e0e9f5d0be..1076ebc036 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,6 +21,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -35,9 +36,6 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     CodeMessageException,
@@ -56,10 +54,9 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -360,10 +357,9 @@ class FederationClient(FederationBase):
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
-        pdus: List[EventBase],
+        pdus: Collection[EventBase],
         room_version: RoomVersion,
         outlier: bool = False,
-        include_none: bool = False,
     ) -> List[EventBase]:
         """Takes a list of PDUs and checks the signatures and hashes of each
         one. If a PDU fails its signature check then we check if we have it in
@@ -380,57 +376,87 @@ class FederationClient(FederationBase):
             pdu
             room_version
             outlier: Whether the events are outliers or not
-            include_none: Whether to include None in the returned list
-                for events that have failed their checks
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
         """
-        deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
-        async def handle_check_result(pdu: EventBase, deferred: Deferred):
-            try:
-                res = await make_deferred_yieldable(deferred)
-            except SynapseError:
-                res = None
+        # We limit how many PDUs we check at once, as if we try to do hundreds
+        # of thousands of PDUs at once we see large memory spikes.
 
-            if not res:
-                # Check local db.
-                res = await self.store.get_event(
-                    pdu.event_id, allow_rejected=True, allow_none=True
-                )
+        valid_pdus = []
 
-            pdu_origin = get_domain_from_id(pdu.sender)
-            if not res and pdu_origin != origin:
-                try:
-                    res = await self.get_pdu(
-                        destinations=[pdu_origin],
-                        event_id=pdu.event_id,
-                        room_version=room_version,
-                        outlier=outlier,
-                        timeout=10000,
-                    )
-                except SynapseError:
-                    pass
+        async def _execute(pdu: EventBase) -> None:
+            valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                pdu=pdu,
+                origin=origin,
+                outlier=outlier,
+                room_version=room_version,
+            )
 
-            if not res:
-                logger.warning(
-                    "Failed to find copy of %s with valid signature", pdu.event_id
-                )
+            if valid_pdu:
+                valid_pdus.append(valid_pdu)
 
-            return res
+        await concurrently_execute(_execute, pdus, 10000)
 
-        handle = preserve_fn(handle_check_result)
-        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+        return valid_pdus
 
-        valid_pdus = await make_deferred_yieldable(
-            defer.gatherResults(deferreds2, consumeErrors=True)
-        ).addErrback(unwrapFirstError)
+    async def _check_sigs_and_hash_and_fetch_one(
+        self,
+        pdu: EventBase,
+        origin: str,
+        room_version: RoomVersion,
+        outlier: bool = False,
+    ) -> Optional[EventBase]:
+        """Takes a PDU and checks its signatures and hashes. If the PDU fails
+        its signature check then we check if we have it in the database and if
+        not then request if from the originating server of that PDU.
 
-        if include_none:
-            return valid_pdus
-        else:
-            return [p for p in valid_pdus if p]
+        If then PDU fails its content hash check then it is redacted.
+
+        Args:
+            origin
+            pdu
+            room_version
+            outlier: Whether the events are outliers or not
+            include_none: Whether to include None in the returned list
+                for events that have failed their checks
+
+        Returns:
+            The PDU (possibly redacted) if it has valid signatures and hashes.
+        """
+
+        res = None
+        try:
+            res = await self._check_sigs_and_hash(room_version, pdu)
+        except SynapseError:
+            pass
+
+        if not res:
+            # Check local db.
+            res = await self.store.get_event(
+                pdu.event_id, allow_rejected=True, allow_none=True
+            )
+
+        pdu_origin = get_domain_from_id(pdu.sender)
+        if not res and pdu_origin != origin:
+            try:
+                res = await self.get_pdu(
+                    destinations=[pdu_origin],
+                    event_id=pdu.event_id,
+                    room_version=room_version,
+                    outlier=outlier,
+                    timeout=10000,
+                )
+            except SynapseError:
+                pass
+
+        if not res:
+            logger.warning(
+                "Failed to find copy of %s with valid signature", pdu.event_id
+            )
+
+        return res
 
     async def get_event_auth(
         self, destination: str, room_id: str, event_id: str
@@ -671,8 +697,6 @@ class FederationClient(FederationBase):
             state = response.state
             auth_chain = response.auth_events
 
-            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
-
             create_event = None
             for e in state:
                 if (e.type, e.state_key) == (EventTypes.Create, ""):
@@ -696,14 +720,29 @@ class FederationClient(FederationBase):
                     % (create_room_version,)
                 )
 
-            valid_pdus = await self._check_sigs_and_hash_and_fetch(
-                destination,
-                list(pdus.values()),
-                outlier=True,
-                room_version=room_version,
+            logger.info(
+                "Processing from send_join %d events", len(state) + len(auth_chain)
             )
 
-            valid_pdus_map = {p.event_id: p for p in valid_pdus}
+            # We now go and check the signatures and hashes for the event. Note
+            # that we limit how many events we process at a time to keep the
+            # memory overhead from exploding.
+            valid_pdus_map: Dict[str, EventBase] = {}
+
+            async def _execute(pdu: EventBase) -> None:
+                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                    pdu=pdu,
+                    origin=destination,
+                    outlier=True,
+                    room_version=room_version,
+                )
+
+                if valid_pdu:
+                    valid_pdus_map[valid_pdu.event_id] = valid_pdu
+
+            await concurrently_execute(
+                _execute, itertools.chain(state, auth_chain), 10000
+            )
 
             # NB: We *need* to copy to ensure that we don't have multiple
             # references being passed on, as that causes... issues.
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5c55bb0125..061102c3c8 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,6 +15,7 @@
 
 import collections
 import inspect
+import itertools
 import logging
 from contextlib import contextmanager
 from typing import (
@@ -160,8 +161,11 @@ class ObservableDeferred:
         )
 
 
+T = TypeVar("T")
+
+
 def concurrently_execute(
-    func: Callable, args: Iterable[Any], limit: int
+    func: Callable[[T], Any], args: Iterable[T], limit: int
 ) -> defer.Deferred:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
@@ -173,20 +177,27 @@ def concurrently_execute(
         limit: Maximum number of conccurent executions.
 
     Returns:
-        Deferred[list]: Resolved when all function invocations have finished.
+        Deferred: Resolved when all function invocations have finished.
     """
     it = iter(args)
 
-    async def _concurrently_execute_inner():
+    async def _concurrently_execute_inner(value: T) -> None:
         try:
             while True:
-                await maybe_awaitable(func(next(it)))
+                await maybe_awaitable(func(value))
+                value = next(it)
         except StopIteration:
             pass
 
+    # We use `itertools.islice` to handle the case where the number of args is
+    # less than the limit, avoiding needlessly spawning unnecessary background
+    # tasks.
     return make_deferred_yieldable(
         defer.gatherResults(
-            [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+            [
+                run_in_background(_concurrently_execute_inner, value)
+                for value in itertools.islice(it, limit)
+            ],
             consumeErrors=True,
         )
     ).addErrback(unwrapFirstError)