diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4870e39652..8c6b839478 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -33,6 +33,7 @@ from typing import (
from prometheus_client import Counter
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@@ -51,7 +52,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
@@ -187,7 +188,7 @@ class FederationClient(FederationBase):
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
- ) -> List[EventBase]:
+ ) -> Optional[List[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@@ -199,9 +200,9 @@ class FederationClient(FederationBase):
"""
logger.debug("backfill extrem=%s", extremities)
- # If there are no extremeties then we've (probably) reached the start.
+ # If there are no extremities then we've (probably) reached the start.
if not extremities:
- return
+ return None
transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit
@@ -284,7 +285,7 @@ class FederationClient(FederationBase):
pdu_list = [
event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
- ]
+ ] # type: List[EventBase]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
@@ -345,6 +346,83 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ async def _check_sigs_and_hash_and_fetch(
+ self,
+ origin: str,
+ pdus: List[EventBase],
+ room_version: str,
+ outlier: bool = False,
+ include_none: bool = False,
+ ) -> List[EventBase]:
+ """Takes a list of PDUs and checks the signatures and hashs of each
+ one. If a PDU fails its signature check then we check if we have it in
+ the database and if not then request if from the originating server of
+ that PDU.
+
+ If a PDU fails its content hash check then it is redacted.
+
+ The given list of PDUs are not modified, instead the function returns
+ a new list.
+
+ Args:
+ origin
+ pdu
+ room_version
+ outlier: Whether the events are outliers or not
+ include_none: Whether to include None in the returned list
+ for events that have failed their checks
+
+ Returns:
+ Deferred : A list of PDUs that have valid signatures and hashes.
+ """
+ deferreds = self._check_sigs_and_hashes(room_version, pdus)
+
+ @defer.inlineCallbacks
+ def handle_check_result(pdu: EventBase, deferred: Deferred):
+ try:
+ res = yield make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
+
+ if not res:
+ # Check local db.
+ res = yield self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
+
+ if not res and pdu.origin != origin:
+ try:
+ res = yield defer.ensureDeferred(
+ self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ room_version=room_version, # type: ignore
+ outlier=outlier,
+ timeout=10000,
+ )
+ )
+ except SynapseError:
+ pass
+
+ if not res:
+ logger.warning(
+ "Failed to find copy of %s with valid signature", pdu.event_id
+ )
+
+ return res
+
+ handle = preserve_fn(handle_check_result)
+ deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+
+ valid_pdus = await make_deferred_yieldable(
+ defer.gatherResults(deferreds2, consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+ if include_none:
+ return valid_pdus
+ else:
+ return [p for p in valid_pdus if p]
+
async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
@@ -615,7 +693,7 @@ class FederationClient(FederationBase):
]
if auth_chain_create_events != [create_event.event_id]:
raise InvalidResponseError(
- "Unexpected create event(s) in auth chain"
+ "Unexpected create event(s) in auth chain: %s"
% (auth_chain_create_events,)
)
|