summary refs log tree commit diff
path: root/synapse/federation/federation_base.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2017-09-20 01:32:42 +0100
committerRichard van der Hoff <richard@matrix.org>2017-09-20 01:32:42 +0100
commit6de74ea6d7394b63c9475e9dfff943188a9ed73b (patch)
treee7ac3b9c1ab2d7f08cfe929d0bc491e34fba43bf /synapse/federation/federation_base.py
parentAdd some more tests for Keyring (diff)
downloadsynapse-6de74ea6d7394b63c9475e9dfff943188a9ed73b.tar.xz
Fix logcontexts in _check_sigs_and_hashes
Diffstat (limited to 'synapse/federation/federation_base.py')
-rw-r--r--synapse/federation/federation_base.py114
1 files changed, 58 insertions, 56 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index cabed33f74..babd9ea078 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.events import spamcheck
 from synapse.events.utils import prune_event
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_context_over_deferred
+from synapse.util import unwrapFirstError, logcontext
 from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
@@ -51,56 +50,52 @@ class FederationBase(object):
         """
         deferreds = self._check_sigs_and_hashes(pdus)
 
-        def callback(pdu):
-            return pdu
+        @defer.inlineCallbacks
+        def handle_check_result(pdu, deferred):
+            try:
+                res = yield logcontext.make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None
 
-        def errback(failure, pdu):
-            failure.trap(SynapseError)
-            return None
-
-        def try_local_db(res, pdu):
             if not res:
                 # Check local db.
-                return self.store.get_event(
+                res = yield self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-            return res
 
-        def try_remote(res, pdu):
             if not res and pdu.origin != origin:
-                return self.get_pdu(
-                    destinations=[pdu.origin],
-                    event_id=pdu.event_id,
-                    outlier=outlier,
-                    timeout=10000,
-                ).addErrback(lambda e: None)
-            return res
-
-        def warn(res, pdu):
+                try:
+                    res = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                        timeout=10000,
+                    )
+                except SynapseError:
+                    pass
+
             if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
-            return res
 
-        for pdu, deferred in zip(pdus, deferreds):
-            deferred.addCallbacks(
-                callback, errback, errbackArgs=[pdu]
-            ).addCallback(
-                try_local_db, pdu
-            ).addCallback(
-                try_remote, pdu
-            ).addCallback(
-                warn, pdu
-            )
+            defer.returnValue(res)
 
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
-            deferreds,
-            consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        handle = logcontext.preserve_fn(handle_check_result)
+        deferreds2 = [
+            handle(pdu, deferred)
+            for pdu, deferred in zip(pdus, deferreds)
+        ]
+
+        valid_pdus = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                deferreds2,
+                consumeErrors=True,
+            )
+        ).addErrback(unwrapFirstError)
 
         if include_none:
             defer.returnValue(valid_pdus)
@@ -108,7 +103,9 @@ class FederationBase(object):
             defer.returnValue([p for p in valid_pdus if p])
 
     def _check_sigs_and_hash(self, pdu):
-        return self._check_sigs_and_hashes([pdu])[0]
+        return logcontext.make_deferred_yieldable(
+            self._check_sigs_and_hashes([pdu])[0],
+        )
 
     def _check_sigs_and_hashes(self, pdus):
         """Checks that each of the received events is correctly signed by the
@@ -123,6 +120,7 @@ class FederationBase(object):
               * returns a redacted version of the event (if the signature
                 matched but the hash did not)
               * throws a SynapseError if the signature check failed.
+            The deferreds run their callbacks in the sentinel logcontext.
         """
 
         redacted_pdus = [
@@ -135,29 +133,33 @@ class FederationBase(object):
             for p in redacted_pdus
         ])
 
-        def callback(_, pdu, redacted):
-            if not check_event_content_hash(pdu):
-                logger.warn(
-                    "Event content has been tampered, redacting %s: %s",
-                    pdu.event_id, pdu.get_pdu_json()
-                )
-                return redacted
-
-            if spamcheck.check_event_for_spam(pdu):
-                logger.warn(
-                    "Event contains spam, redacting %s: %s",
-                    pdu.event_id, pdu.get_pdu_json()
-                )
-                return redacted
+        ctx = logcontext.LoggingContext.current_context()
 
-            return pdu
+        def callback(_, pdu, redacted):
+            with logcontext.PreserveLoggingContext(ctx):
+                if not check_event_content_hash(pdu):
+                    logger.warn(
+                        "Event content has been tampered, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                if spamcheck.check_event_for_spam(pdu):
+                    logger.warn(
+                        "Event contains spam, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                return pdu
 
         def errback(failure, pdu):
             failure.trap(SynapseError)
-            logger.warn(
-                "Signature check failed for %s",
-                pdu.event_id,
-            )
+            with logcontext.PreserveLoggingContext(ctx):
+                logger.warn(
+                    "Signature check failed for %s",
+                    pdu.event_id,
+                )
             return failure
 
         for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):