summary refs log tree commit diff
path: root/synapse/federation/federation_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_base.py')
-rw-r--r--synapse/federation/federation_base.py243
1 files changed, 80 insertions, 163 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 3fe496dcd3..c066617b92 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -14,11 +14,6 @@
 # limitations under the License.
 import logging
 from collections import namedtuple
-from typing import Iterable, List
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred, DeferredList
-from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -28,11 +23,6 @@ from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
-from synapse.logging.context import (
-    PreserveLoggingContext,
-    current_context,
-    make_deferred_yieldable,
-)
 from synapse.types import JsonDict, get_domain_from_id
 
 logger = logging.getLogger(__name__)
@@ -48,112 +38,82 @@ class FederationBase:
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    def _check_sigs_and_hash(
+    async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
-    ) -> Deferred:
-        return make_deferred_yieldable(
-            self._check_sigs_and_hashes(room_version, [pdu])[0]
-        )
-
-    def _check_sigs_and_hashes(
-        self, room_version: RoomVersion, pdus: List[EventBase]
-    ) -> List[Deferred]:
-        """Checks that each of the received events is correctly signed by the
-        sending server.
+    ) -> EventBase:
+        """Checks that event is correctly signed by the sending server.
 
         Args:
-            room_version: The room version of the PDUs
-            pdus: the events to be checked
+            room_version: The room version of the PDU
+            pdu: the event to be checked
 
         Returns:
-            For each input event, a deferred which:
-              * returns the original event if the checks pass
-              * returns a redacted version of the event (if the signature
+              * the original event if the checks pass
+              * a redacted version of the event (if the signature
                 matched but the hash did not)
-              * throws a SynapseError if the signature check failed.
-            The deferreds run their callbacks in the sentinel
-        """
-        deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
-
-        ctx = current_context()
-
-        @defer.inlineCallbacks
-        def callback(_, pdu: EventBase):
-            with PreserveLoggingContext(ctx):
-                if not check_event_content_hash(pdu):
-                    # let's try to distinguish between failures because the event was
-                    # redacted (which are somewhat expected) vs actual ball-tampering
-                    # incidents.
-                    #
-                    # This is just a heuristic, so we just assume that if the keys are
-                    # about the same between the redacted and received events, then the
-                    # received event was probably a redacted copy (but we then use our
-                    # *actual* redacted copy to be on the safe side.)
-                    redacted_event = prune_event(pdu)
-                    if set(redacted_event.keys()) == set(pdu.keys()) and set(
-                        redacted_event.content.keys()
-                    ) == set(pdu.content.keys()):
-                        logger.info(
-                            "Event %s seems to have been redacted; using our redacted "
-                            "copy",
-                            pdu.event_id,
-                        )
-                    else:
-                        logger.warning(
-                            "Event %s content has been tampered, redacting",
-                            pdu.event_id,
-                        )
-                    return redacted_event
-
-                result = yield defer.ensureDeferred(
-                    self.spam_checker.check_event_for_spam(pdu)
+              * throws a SynapseError if the signature check failed."""
+        try:
+            await _check_sigs_on_pdu(self.keyring, room_version, pdu)
+        except SynapseError as e:
+            logger.warning(
+                "Signature check failed for %s: %s",
+                pdu.event_id,
+                e,
+            )
+            raise
+
+        if not check_event_content_hash(pdu):
+            # let's try to distinguish between failures because the event was
+            # redacted (which are somewhat expected) vs actual ball-tampering
+            # incidents.
+            #
+            # This is just a heuristic, so we just assume that if the keys are
+            # about the same between the redacted and received events, then the
+            # received event was probably a redacted copy (but we then use our
+            # *actual* redacted copy to be on the safe side.)
+            redacted_event = prune_event(pdu)
+            if set(redacted_event.keys()) == set(pdu.keys()) and set(
+                redacted_event.content.keys()
+            ) == set(pdu.content.keys()):
+                logger.info(
+                    "Event %s seems to have been redacted; using our redacted copy",
+                    pdu.event_id,
                 )
-
-                if result:
-                    logger.warning(
-                        "Event contains spam, redacting %s: %s",
-                        pdu.event_id,
-                        pdu.get_pdu_json(),
-                    )
-                    return prune_event(pdu)
-
-                return pdu
-
-        def errback(failure: Failure, pdu: EventBase):
-            failure.trap(SynapseError)
-            with PreserveLoggingContext(ctx):
+            else:
                 logger.warning(
-                    "Signature check failed for %s: %s",
+                    "Event %s content has been tampered, redacting",
                     pdu.event_id,
-                    failure.getErrorMessage(),
                 )
-            return failure
+            return redacted_event
 
-        for deferred, pdu in zip(deferreds, pdus):
-            deferred.addCallbacks(
-                callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
+        result = await self.spam_checker.check_event_for_spam(pdu)
+
+        if result:
+            logger.warning(
+                "Event contains spam, redacting %s: %s",
+                pdu.event_id,
+                pdu.get_pdu_json(),
             )
+            return prune_event(pdu)
 
-        return deferreds
+        return pdu
 
 
 class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass
 
 
-def _check_sigs_on_pdus(
-    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
-) -> List[Deferred]:
+async def _check_sigs_on_pdu(
+    keyring: Keyring, room_version: RoomVersion, pdu: EventBase
+) -> None:
     """Check that the given events are correctly signed
 
+    Raise a SynapseError if the event wasn't correctly signed.
+
     Args:
         keyring: keyring object to do the checks
         room_version: the room version of the PDUs
         pdus: the events to be checked
-
-    Returns:
-        A Deferred for each event in pdus, which will either succeed if
-           the signatures are valid, or fail (with a SynapseError) if not.
     """
 
     # we want to check that the event is signed by:
@@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
     # let's start by getting the domain for each pdu, and flattening the event back
     # to JSON.
 
-    pdus_to_check = [
-        PduToCheckSig(
-            pdu=p,
-            sender_domain=get_domain_from_id(p.sender),
-            deferreds=[],
-        )
-        for p in pdus
-    ]
-
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
-    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
-
-    more_deferreds = keyring.verify_events_for_server(
-        [
-            (
-                p.sender_domain,
-                p.pdu,
-                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+    if not _is_invite_via_3pid(pdu):
+        try:
+            await keyring.verify_event_for_server(
+                get_domain_from_id(pdu.sender),
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
             )
-            for p in pdus_to_check_sender
-        ]
-    )
-
-    def sender_err(e, pdu_to_check):
-        errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
-            pdu_to_check.pdu.event_id,
-            pdu_to_check.sender_domain,
-            e.getErrorMessage(),
-        )
-        raise SynapseError(403, errmsg, Codes.FORBIDDEN)
-
-    for p, d in zip(pdus_to_check_sender, more_deferreds):
-        d.addErrback(sender_err, p)
-        p.deferreds.append(d)
+        except Exception as e:
+            errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
+                pdu.event_id,
+                get_domain_from_id(pdu.sender),
+                e,
+            )
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
     # now let's look for events where the sender's domain is different to the
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if room_version.event_format == EventFormatVersions.V1:
-        pdus_to_check_event_id = [
-            p
-            for p in pdus_to_check
-            if p.sender_domain != get_domain_from_id(p.pdu.event_id)
-        ]
-
-        more_deferreds = keyring.verify_events_for_server(
-            [
-                (
-                    get_domain_from_id(p.pdu.event_id),
-                    p.pdu,
-                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                )
-                for p in pdus_to_check_event_id
-            ]
-        )
-
-        def event_err(e, pdu_to_check):
+    if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
+        pdu.event_id
+    ) != get_domain_from_id(pdu.sender):
+        try:
+            await keyring.verify_event_for_server(
+                get_domain_from_id(pdu.event_id),
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+            )
+        except Exception as e:
             errmsg = (
-                "event id %s: unable to verify signature for event id domain: %s"
-                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
+                "event id %s: unable to verify signature for event id domain %s: %s"
+                % (
+                    pdu.event_id,
+                    get_domain_from_id(pdu.event_id),
+                    e,
+                )
             )
             raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
-        for p, d in zip(pdus_to_check_event_id, more_deferreds):
-            d.addErrback(event_err, p)
-            p.deferreds.append(d)
-
-    # replace lists of deferreds with single Deferreds
-    return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-
-
-def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
-    """Given a list of deferreds, either return the single deferred,
-    combine into a DeferredList, or return an already resolved deferred.
-    """
-    if len(deferreds) > 1:
-        return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
-    elif len(deferreds) == 1:
-        return deferreds[0]
-    else:
-        return defer.succeed(None)
-
 
 def _is_invite_via_3pid(event: EventBase) -> bool:
     return (