diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e0e9f5d0be..1076ebc036 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,6 +21,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -35,9 +36,6 @@ from typing import (
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
@@ -56,10 +54,9 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
-from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -360,10 +357,9 @@ class FederationClient(FederationBase):
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
- pdus: List[EventBase],
+ pdus: Collection[EventBase],
room_version: RoomVersion,
outlier: bool = False,
- include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
@@ -380,57 +376,87 @@ class FederationClient(FederationBase):
pdu
room_version
outlier: Whether the events are outliers or not
- include_none: Whether to include None in the returned list
- for events that have failed their checks
Returns:
A list of PDUs that have valid signatures and hashes.
"""
- deferreds = self._check_sigs_and_hashes(room_version, pdus)
- async def handle_check_result(pdu: EventBase, deferred: Deferred):
- try:
- res = await make_deferred_yieldable(deferred)
- except SynapseError:
- res = None
+ # We limit how many PDUs we check at once, since checking hundreds of
+ # thousands of PDUs at once causes large memory spikes.
- if not res:
- # Check local db.
- res = await self.store.get_event(
- pdu.event_id, allow_rejected=True, allow_none=True
- )
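+ # Events that pass their checks (possibly as redacted or locally
+ # stored copies) are collected here by the helper below.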
+ valid_pdus = []
- pdu_origin = get_domain_from_id(pdu.sender)
- if not res and pdu_origin != origin:
- try:
- res = await self.get_pdu(
- destinations=[pdu_origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
- except SynapseError:
- pass
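+ # Per-PDU worker: delegate to _check_sigs_and_hash_and_fetch_one and
+ # keep whatever valid copy (if any) it comes back with.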
+ async def _execute(pdu: EventBase) -> None:
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=pdu,
+ origin=origin,
+ outlier=outlier,
+ room_version=room_version,
+ )
- if not res:
- logger.warning(
- "Failed to find copy of %s with valid signature", pdu.event_id
- )
+ if valid_pdu:
+ valid_pdus.append(valid_pdu)
- return res
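+ # Run the checks with bounded concurrency: at most 10,000 PDUs are
+ # verified in flight at any one time.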
+ await concurrently_execute(_execute, pdus, 10000)
- handle = preserve_fn(handle_check_result)
- deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
+ return valid_pdus
- valid_pdus = await make_deferred_yieldable(
- defer.gatherResults(deferreds2, consumeErrors=True)
- ).addErrback(unwrapFirstError)
+ async def _check_sigs_and_hash_and_fetch_one(
+ self,
+ pdu: EventBase,
+ origin: str,
+ room_version: RoomVersion,
+ outlier: bool = False,
+ ) -> Optional[EventBase]:
+ """Takes a PDU and checks its signatures and hashes. If the PDU fails
+ its signature check then we check if we have it in the database and if
+ not then request it from the originating server of that PDU.
- if include_none:
- return valid_pdus
- else:
- return [p for p in valid_pdus if p]
+ If the PDU fails its content hash check then it is redacted.
+
+ Args:
+ pdu: The event to be checked
+ origin: The server we received this event from
+ room_version: The version of the room the event belongs to
+ outlier: Whether the event is an outlier or not
+
+ Returns:
+ The PDU (possibly redacted) if it has valid signatures and hashes,
+ otherwise None.
+ """
+
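+ # Verify the event as received: a failed signature check surfaces as a
+ # SynapseError, while a failed content hash check yields a redacted copy.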
+ res = None
+ try:
+ res = await self._check_sigs_and_hash(room_version, pdu)
+ except SynapseError:
+ pass
+
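+ # Signature check failed: fall back to any copy we already have
+ # locally, even one that was previously rejected.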
+ if not res:
+ # Check local db.
+ res = await self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
+
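+ # Still nothing: as a last resort, ask the event's originating server
+ # for a copy, unless that is where this PDU came from in the first place.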
+ pdu_origin = get_domain_from_id(pdu.sender)
+ if not res and pdu_origin != origin:
+ try:
+ res = await self.get_pdu(
+ destinations=[pdu_origin],
+ event_id=pdu.event_id,
+ room_version=room_version,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
+ if not res:
+ logger.warning(
+ "Failed to find copy of %s with valid signature", pdu.event_id
+ )
+
+ return res
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
@@ -671,8 +697,6 @@ class FederationClient(FederationBase):
state = response.state
auth_chain = response.auth_events
- pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
-
create_event = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
@@ -696,14 +720,29 @@ class FederationClient(FederationBase):
% (create_room_version,)
)
- valid_pdus = await self._check_sigs_and_hash_and_fetch(
- destination,
- list(pdus.values()),
- outlier=True,
- room_version=room_version,
+ logger.info(
+ "Processing %d events from send_join response", len(state) + len(auth_chain)
)
- valid_pdus_map = {p.event_id: p for p in valid_pdus}
+ # We now go and check the signatures and hashes for the events. Note
+ # that we limit how many events we process at a time to keep the
+ # memory overhead from exploding.
+ valid_pdus_map: Dict[str, EventBase] = {}
+
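+ # Same pattern as in _check_sigs_and_hash_and_fetch, except successes
+ # are stored in a dict keyed by event ID.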
+ async def _execute(pdu: EventBase) -> None:
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=pdu,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu:
+ valid_pdus_map[valid_pdu.event_id] = valid_pdu
+
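+ # Check the state events and auth chain together, again capping the
+ # number of in-flight checks at 10,000.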
+ await concurrently_execute(
+ _execute, itertools.chain(state, auth_chain), 10000
+ )
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.