summary refs log tree commit diff
path: root/synapse/federation/federation_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_base.py')
-rw-r--r--synapse/federation/federation_base.py166
1 files changed, 49 insertions, 117 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5a1e23a145..c0012c6872 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,27 +15,28 @@
 # limitations under the License.
 import logging
 from collections import namedtuple
+from typing import Iterable, List
 
 import six
 
 from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import event_type_from_format_version
-from synapse.events.utils import prune_event
+from synapse.crypto.keyring import Keyring
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
-    preserve_fn,
 )
-from synapse.types import get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.types import JsonDict, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -49,92 +51,25 @@ class FederationBase(object):
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(
-        self, origin, pdus, room_version, outlier=False, include_none=False
-    ):
-        """Takes a list of PDUs and checks the signatures and hashs of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
-
-        If a PDU fails its content hash check then it is redacted.
-
-        The given list of PDUs are not modified, instead the function returns
-        a new list.
-
-        Args:
-            origin (str)
-            pdu (list)
-            room_version (str)
-            outlier (bool): Whether the events are outliers or not
-            include_none (str): Whether to include None in the returned list
-                for events that have failed their checks
-
-        Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
-        """
-        deferreds = self._check_sigs_and_hashes(room_version, pdus)
-
-        @defer.inlineCallbacks
-        def handle_check_result(pdu, deferred):
-            try:
-                res = yield make_deferred_yieldable(deferred)
-            except SynapseError:
-                res = None
-
-            if not res:
-                # Check local db.
-                res = yield self.store.get_event(
-                    pdu.event_id, allow_rejected=True, allow_none=True
-                )
-
-            if not res and pdu.origin != origin:
-                try:
-                    res = yield self.get_pdu(
-                        destinations=[pdu.origin],
-                        event_id=pdu.event_id,
-                        room_version=room_version,
-                        outlier=outlier,
-                        timeout=10000,
-                    )
-                except SynapseError:
-                    pass
-
-            if not res:
-                logger.warn(
-                    "Failed to find copy of %s with valid signature", pdu.event_id
-                )
-
-            return res
-
-        handle = preserve_fn(handle_check_result)
-        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
-
-        valid_pdus = yield make_deferred_yieldable(
-            defer.gatherResults(deferreds2, consumeErrors=True)
-        ).addErrback(unwrapFirstError)
-
-        if include_none:
-            return valid_pdus
-        else:
-            return [p for p in valid_pdus if p]
-
-    def _check_sigs_and_hash(self, room_version, pdu):
+    def _check_sigs_and_hash(
+        self, room_version: RoomVersion, pdu: EventBase
+    ) -> Deferred:
         return make_deferred_yieldable(
             self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
-    def _check_sigs_and_hashes(self, room_version, pdus):
+    def _check_sigs_and_hashes(
+        self, room_version: RoomVersion, pdus: List[EventBase]
+    ) -> List[Deferred]:
         """Checks that each of the received events is correctly signed by the
         sending server.
 
         Args:
-            room_version (str): The room version of the PDUs
-            pdus (list[FrozenEvent]): the events to be checked
+            room_version: The room version of the PDUs
+            pdus: the events to be checked
 
         Returns:
-            list[Deferred]: for each input event, a deferred which:
+            For each input event, a deferred which:
               * returns the original event if the checks pass
               * returns a redacted version of the event (if the signature
                 matched but the hash did not)
@@ -143,9 +78,9 @@ class FederationBase(object):
         """
         deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
 
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
-        def callback(_, pdu):
+        def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
                     # let's try to distinguish between failures because the event was
@@ -173,7 +108,7 @@ class FederationBase(object):
                     return redacted_event
 
                 if self.spam_checker.check_event_for_spam(pdu):
-                    logger.warn(
+                    logger.warning(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id,
                         pdu.get_pdu_json(),
@@ -182,10 +117,10 @@ class FederationBase(object):
 
                 return pdu
 
-        def errback(failure, pdu):
+        def errback(failure: Failure, pdu: EventBase):
             failure.trap(SynapseError)
             with PreserveLoggingContext(ctx):
-                logger.warn(
+                logger.warning(
                     "Signature check failed for %s: %s",
                     pdu.event_id,
                     failure.getErrorMessage(),
@@ -208,16 +143,18 @@ class PduToCheckSig(
     pass
 
 
-def _check_sigs_on_pdus(keyring, room_version, pdus):
+def _check_sigs_on_pdus(
+    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
+) -> List[Deferred]:
     """Check that the given events are correctly signed
 
     Args:
-        keyring (synapse.crypto.Keyring): keyring object to do the checks
-        room_version (str): the room version of the PDUs
-        pdus (Collection[EventBase]): the events to be checked
+        keyring: keyring object to do the checks
+        room_version: the room version of the PDUs
+        pdus: the events to be checked
 
     Returns:
-        List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+        A Deferred for each event in pdus, which will either succeed if
            the signatures are valid, or fail (with a SynapseError) if not.
     """
 
@@ -252,10 +189,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
         for p in pdus
     ]
 
-    v = KNOWN_ROOM_VERSIONS.get(room_version)
-    if not v:
-        raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -265,7 +198,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
             (
                 p.sender_domain,
                 p.redacted_pdu_json,
-                p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                 p.pdu.event_id,
             )
             for p in pdus_to_check_sender
@@ -278,9 +211,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
             pdu_to_check.sender_domain,
             e.getErrorMessage(),
         )
-        # XX not really sure if these are the right codes, but they are what
-        # we've done for ages
-        raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+        raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
     for p, d in zip(pdus_to_check_sender, more_deferreds):
         d.addErrback(sender_err, p)
@@ -290,7 +221,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if v.event_format == EventFormatVersions.V1:
+    if room_version.event_format == EventFormatVersions.V1:
         pdus_to_check_event_id = [
             p
             for p in pdus_to_check
@@ -302,7 +233,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
                 (
                     get_domain_from_id(p.pdu.event_id),
                     p.redacted_pdu_json,
-                    p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                     p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
@@ -314,8 +245,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
                 "event id %s: unable to verify signature for event id domain: %s"
                 % (pdu_to_check.pdu.event_id, e.getErrorMessage())
             )
-            # XX as above: not really sure if these are the right codes
-            raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
         for p, d in zip(pdus_to_check_event_id, more_deferreds):
             d.addErrback(event_err, p)
@@ -325,7 +255,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
 
 
-def _flatten_deferred_list(deferreds):
+def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
     """Given a list of deferreds, either return the single deferred,
     combine into a DeferredList, or return an already resolved deferred.
     """
@@ -337,7 +267,7 @@ def _flatten_deferred_list(deferreds):
         return defer.succeed(None)
 
 
-def _is_invite_via_3pid(event):
+def _is_invite_via_3pid(event: EventBase) -> bool:
     return (
         event.type == EventTypes.Member
         and event.membership == Membership.INVITE
@@ -345,16 +275,15 @@ def _is_invite_via_3pid(event):
     )
 
 
-def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
-    """Construct a FrozenEvent from an event json received over federation
+def event_from_pdu_json(
+    pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
+) -> EventBase:
+    """Construct an EventBase from an event json received over federation
 
     Args:
-        pdu_json (object): pdu as received over federation
-        event_format_version (int): The event format version
-        outlier (bool): True to mark this event as an outlier
-
-    Returns:
-        FrozenEvent
+        pdu_json: pdu as received over federation
+        room_version: The version of the room this event belongs to
+        outlier: True to mark this event as an outlier
 
     Raises:
         SynapseError: if the pdu is missing required fields or is otherwise
@@ -373,8 +302,11 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
     elif depth > MAX_DEPTH:
         raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
 
-    event = event_type_from_format_version(event_format_version)(pdu_json)
+    # Validate that the JSON conforms to the specification.
+    if room_version.strict_canonicaljson:
+        validate_canonicaljson(pdu_json)
 
+    event = make_event_from_dict(pdu_json, room_version)
     event.internal_metadata.outlier = outlier
 
     return event