summary refs log tree commit diff
path: root/synapse/federation/federation_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_base.py')
-rw-r--r--synapse/federation/federation_base.py78
1 files changed, 32 insertions, 46 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index fc5cfb7d83..58b929363f 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -44,8 +44,9 @@ class FederationBase(object):
         self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
-                                       outlier=False, include_none=False):
+    def _check_sigs_and_hash_and_fetch(
+        self, origin, pdus, room_version, outlier=False, include_none=False
+    ):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
         the database and if not then request if from the originating server of
@@ -79,9 +80,7 @@ class FederationBase(object):
             if not res:
                 # Check local db.
                 res = yield self.store.get_event(
-                    pdu.event_id,
-                    allow_rejected=True,
-                    allow_none=True,
+                    pdu.event_id, allow_rejected=True, allow_none=True
                 )
 
             if not res and pdu.origin != origin:
@@ -98,23 +97,16 @@ class FederationBase(object):
 
             if not res:
                 logger.warn(
-                    "Failed to find copy of %s with valid signature",
-                    pdu.event_id,
+                    "Failed to find copy of %s with valid signature", pdu.event_id
                 )
 
             defer.returnValue(res)
 
         handle = logcontext.preserve_fn(handle_check_result)
-        deferreds2 = [
-            handle(pdu, deferred)
-            for pdu, deferred in zip(pdus, deferreds)
-        ]
+        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
 
         valid_pdus = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                deferreds2,
-                consumeErrors=True,
-            )
+            defer.gatherResults(deferreds2, consumeErrors=True)
         ).addErrback(unwrapFirstError)
 
         if include_none:
@@ -124,7 +116,7 @@ class FederationBase(object):
 
     def _check_sigs_and_hash(self, room_version, pdu):
         return logcontext.make_deferred_yieldable(
-            self._check_sigs_and_hashes(room_version, [pdu])[0],
+            self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
     def _check_sigs_and_hashes(self, room_version, pdus):
@@ -159,11 +151,9 @@ class FederationBase(object):
                     # received event was probably a redacted copy (but we then use our
                     # *actual* redacted copy to be on the safe side.)
                     redacted_event = prune_event(pdu)
-                    if (
-                        set(redacted_event.keys()) == set(pdu.keys()) and
-                        set(six.iterkeys(redacted_event.content))
-                            == set(six.iterkeys(pdu.content))
-                    ):
+                    if set(redacted_event.keys()) == set(pdu.keys()) and set(
+                        six.iterkeys(redacted_event.content)
+                    ) == set(six.iterkeys(pdu.content)):
                         logger.info(
                             "Event %s seems to have been redacted; using our redacted "
                             "copy",
@@ -172,14 +162,16 @@ class FederationBase(object):
                     else:
                         logger.warning(
                             "Event %s content has been tampered, redacting",
-                            pdu.event_id, pdu.get_pdu_json(),
+                            pdu.event_id,
+                            pdu.get_pdu_json(),
                         )
                     return redacted_event
 
                 if self.spam_checker.check_event_for_spam(pdu):
                     logger.warn(
                         "Event contains spam, redacting %s: %s",
-                        pdu.event_id, pdu.get_pdu_json()
+                        pdu.event_id,
+                        pdu.get_pdu_json(),
                     )
                     return prune_event(pdu)
 
@@ -190,23 +182,24 @@ class FederationBase(object):
             with logcontext.PreserveLoggingContext(ctx):
                 logger.warn(
                     "Signature check failed for %s: %s",
-                    pdu.event_id, failure.getErrorMessage(),
+                    pdu.event_id,
+                    failure.getErrorMessage(),
                 )
             return failure
 
         for deferred, pdu in zip(deferreds, pdus):
             deferred.addCallbacks(
-                callback, errback,
-                callbackArgs=[pdu],
-                errbackArgs=[pdu],
+                callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
             )
 
         return deferreds
 
 
-class PduToCheckSig(namedtuple("PduToCheckSig", [
-    "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
-])):
+class PduToCheckSig(
+    namedtuple(
+        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
+    )
+):
     pass
 
 
@@ -260,10 +253,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
 
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
-    pdus_to_check_sender = [
-        p for p in pdus_to_check
-        if not _is_invite_via_3pid(p.pdu)
-    ]
+    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
 
     more_deferreds = keyring.verify_json_objects_for_server(
         [
@@ -297,7 +287,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     # (ie, the room version uses old-style non-hash event IDs).
     if v.event_format == EventFormatVersions.V1:
         pdus_to_check_event_id = [
-            p for p in pdus_to_check
+            p
+            for p in pdus_to_check
             if p.sender_domain != get_domain_from_id(p.pdu.event_id)
         ]
 
@@ -315,10 +306,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
 
         def event_err(e, pdu_to_check):
             errmsg = (
-                "event id %s: unable to verify signature for event id domain: %s" % (
-                    pdu_to_check.pdu.event_id,
-                    e.getErrorMessage(),
-                )
+                "event id %s: unable to verify signature for event id domain: %s"
+                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
             )
             # XX as above: not really sure if these are the right codes
             raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
@@ -368,21 +357,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
     """
     # we could probably enforce a bunch of other fields here (room_id, sender,
     # origin, etc etc)
-    assert_params_in_dict(pdu_json, ('type', 'depth'))
+    assert_params_in_dict(pdu_json, ("type", "depth"))
 
-    depth = pdu_json['depth']
+    depth = pdu_json["depth"]
     if not isinstance(depth, six.integer_types):
-        raise SynapseError(400, "Depth %r not an intger" % (depth, ),
-                           Codes.BAD_JSON)
+        raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
 
     if depth < 0:
         raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
     elif depth > MAX_DEPTH:
         raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
 
-    event = event_type_from_format_version(event_format_version)(
-        pdu_json,
-    )
+    event = event_type_from_format_version(event_format_version)(pdu_json)
 
     event.internal_metadata.outlier = outlier