diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..a0f5d40eb3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
-from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
import logging
+from synapse.api.errors import SynapseError
+from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events.utils import prune_event
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +49,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
-
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
+
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -114,15 +102,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
- """Throws a SynapseError if a PDU does not have the correct
- signatures.
+ """Checks that each of the received events is correctly signed by the
+ sending server.
+
+ Args:
+ pdus (list[FrozenEvent]): the events to be checked
Returns:
- FrozenEvent: Either the given event or it redacted if it failed the
- content hash check.
+ list[Deferred]: for each input event, a deferred which:
+ * returns the original event if the checks pass
+ * returns a redacted version of the event (if the signature
+ matched but the hash did not)
+ * throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,26 +127,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
+ ctx = logcontext.LoggingContext.current_context()
+
def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
- return pdu
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if self.spam_checker.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|