diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 022bbf7dad..deb40f4610 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,26 +14,19 @@
import abc
import logging
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- Hashable,
- Iterable,
- List,
- Optional,
- Set,
- Tuple,
-)
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter
+from twisted.internet import defer
+
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
LaterGauge,
event_processing_loop_counter,
@@ -262,27 +255,15 @@ class FederationSender(AbstractFederationSender):
if not events and next_token >= self._last_poked_id:
break
- async def get_destinations_for_event(
- event: EventBase,
- ) -> Collection[str]:
- """Computes the destinations to which this event must be sent.
-
- This returns an empty tuple when there are no destinations to send to,
- or if this event is not from this homeserver and it is not sending
- it on behalf of another server.
-
- Will also filter out destinations which this sender is not responsible for,
- if multiple federation senders exist.
- """
-
+ async def handle_event(event: EventBase) -> None:
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None:
- return ()
+ return
if not event.internal_metadata.should_proactively_send():
- return ()
+ return
destinations = None # type: Optional[Set[str]]
if not event.prev_event_ids():
@@ -317,7 +298,7 @@ class FederationSender(AbstractFederationSender):
"Failed to calculate hosts in room for event: %s",
event.event_id,
)
- return ()
+ return
destinations = {
d
@@ -327,15 +308,17 @@ class FederationSender(AbstractFederationSender):
)
}
- destinations.discard(self.server_name)
-
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
destinations.discard(send_on_behalf_of)
+ logger.debug("Sending %s to %r", event, destinations)
+
if destinations:
+ await self._send_pdu(event, destinations)
+
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -343,29 +326,24 @@ class FederationSender(AbstractFederationSender):
"federation_sender"
).observe((now - ts) / 1000)
- return destinations
- return ()
-
- async def get_federatable_events_and_destinations(
- events: Iterable[EventBase],
- ) -> List[Tuple[EventBase, Collection[str]]]:
- with Measure(self.clock, "get_destinations_for_events"):
- # Fetch federation destinations per event,
- # skip if get_destinations_for_event returns an empty collection,
- # return list of event->destinations pairs.
- return [
- (event, dests)
- for (event, dests) in [
- (event, await get_destinations_for_event(event))
- for event in events
- ]
- if dests
- ]
-
- events_and_dests = await get_federatable_events_and_destinations(events)
-
- # Send corresponding events to each destination queue
- await self._distribute_events(events_and_dests)
+ async def handle_room_events(events: Iterable[EventBase]) -> None:
+ with Measure(self.clock, "handle_room_events"):
+ for event in events:
+ await handle_event(event)
+
+ events_by_room = {} # type: Dict[str, List[EventBase]]
+ for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(handle_room_events, evs)
+ for evs in events_by_room.values()
+ ],
+ consumeErrors=True,
+ )
+ )
await self.store.update_federation_out_pos("events", next_token)
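
The per-room fan-out above preserves ordering where it matters: events within a room are awaited sequentially inside handle_room_events, while separate rooms proceed concurrently. Below is a minimal standalone sketch of the same pattern using plain Twisted primitives (defer.ensureDeferred standing in for Synapse's logcontext-aware run_in_background, and no make_deferred_yieldable); the names here are illustrative, not Synapse APIs:

    from twisted.internet import defer

    async def handle_room_events(room_id, events):
        # Events within one room are awaited in order.
        for event in events:
            print("room %s: sending %s" % (room_id, event))

    def fan_out_by_room(events_by_room):
        # One Deferred per room; rooms run concurrently, and
        # consumeErrors=True stops one failing room from leaving the
        # other rooms' errors unhandled.
        return defer.gatherResults(
            [
                defer.ensureDeferred(handle_room_events(room_id, evs))
                for room_id, evs in events_by_room.items()
            ],
            consumeErrors=True,
        )
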
@@ -383,7 +361,7 @@ class FederationSender(AbstractFederationSender):
events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels("federation_sender").inc(
- len({event.room_id for event in events})
+ len(events_by_room)
)
event_processing_loop_counter.labels("federation_sender").inc()
@@ -395,53 +373,34 @@ class FederationSender(AbstractFederationSender):
finally:
self._is_processing = False
- async def _distribute_events(
- self,
- events_and_dests: Iterable[Tuple[EventBase, Collection[str]]],
- ) -> None:
- """Distribute events to the respective per_destination queues.
-
- Also persists last-seen per-room stream_ordering to 'destination_rooms'.
-
- Args:
- events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]).
- Every event is paired with its intended destinations (in federation).
- """
- # Tuples of room_id + destination to their max-seen stream_ordering
- room_with_dest_stream_ordering = {} # type: Dict[Tuple[str, str], int]
-
- # List of events to send to each destination
- events_by_dest = {} # type: Dict[str, List[EventBase]]
+ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+        # Hand the PDU to the queue for each destination; each
+        # PerDestinationQueue keeps its own list of pending PDUs and
+        # starts the transmission loop if one is not already running.
- # For each event-destinations pair...
- for event, destinations in events_and_dests:
+ destinations = set(destinations)
+ destinations.discard(self.server_name)
+        logger.debug("Sending to: %s", destinations)
- # (we got this from the database, it's filled)
- assert event.internal_metadata.stream_ordering
-
- sent_pdus_destination_dist_total.inc(len(destinations))
- sent_pdus_destination_dist_count.inc()
+ if not destinations:
+ return
- # ...iterate over those destinations..
- for destination in destinations:
- # ...update their stream-ordering...
- room_with_dest_stream_ordering[(event.room_id, destination)] = max(
- event.internal_metadata.stream_ordering,
- room_with_dest_stream_ordering.get((event.room_id, destination), 0),
- )
+ sent_pdus_destination_dist_total.inc(len(destinations))
+ sent_pdus_destination_dist_count.inc()
- # ...and add the event to each destination queue.
- events_by_dest.setdefault(destination, []).append(event)
+ assert pdu.internal_metadata.stream_ordering
- # Bulk-store destination_rooms stream_ids
- await self.store.bulk_store_destination_rooms_entries(
- room_with_dest_stream_ordering
+ # track the fact that we have a PDU for these destinations,
+ # to allow us to perform catch-up later on if the remote is unreachable
+ # for a while.
+ await self.store.store_destination_rooms_entries(
+ destinations,
+ pdu.room_id,
+ pdu.internal_metadata.stream_ordering,
)
- for destination, pdus in events_by_dest.items():
- logger.debug("Sending %d pdus to %s", len(pdus), destination)
-
- self._get_per_destination_queue(destination).send_pdus(pdus)
+ for destination in destinations:
+ self._get_per_destination_queue(destination).send_pdu(pdu)
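
_send_pdu dispatches through self._get_per_destination_queue(destination), which this patch does not show. A plausible sketch of that lazy per-destination lookup follows; the class bodies are assumptions for illustration, not code from this patch:

    from typing import Dict, List

    class PerDestinationQueue:
        # Stand-in for the real per-destination queue.
        def __init__(self, destination):
            self.destination = destination
            self._pending_pdus = []  # type: List[object]

        def send_pdu(self, pdu):
            self._pending_pdus.append(pdu)

    class FederationSenderSketch:
        def __init__(self):
            # One queue per remote homeserver, created on first use.
            self._queues = {}  # type: Dict[str, PerDestinationQueue]

        def _get_per_destination_queue(self, destination):
            queue = self._queues.get(destination)
            if queue is None:
                queue = PerDestinationQueue(destination)
                self._queues[destination] = queue
            return queue
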
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3bb66bce32..3b053ebcfb 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -154,22 +154,19 @@ class PerDestinationQueue:
+ len(self._pending_edus_keyed)
)
- def send_pdus(self, pdus: Iterable[EventBase]) -> None:
- """Add PDUs to the queue, and start the transmission loop if necessary
+ def send_pdu(self, pdu: EventBase) -> None:
+ """Add a PDU to the queue, and start the transmission loop if necessary
Args:
- pdus: pdus to send
+            pdu: the PDU to send
"""
if not self._catching_up or self._last_successful_stream_ordering is None:
# only enqueue the PDU if we are not catching up (False) or do not
# yet know if we have anything to catch up (None)
- self._pending_pdus.extend(pdus)
+ self._pending_pdus.append(pdu)
else:
- self._catchup_last_skipped = max(
- pdu.internal_metadata.stream_ordering
- for pdu in pdus
- if pdu.internal_metadata.stream_ordering is not None
- )
+ assert pdu.internal_metadata.stream_ordering
+ self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
self.attempt_new_transaction()
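
A simplified model of the gate that send_pdu now applies per PDU. The class below is a toy (stream orderings as plain ints), but the branch mirrors the patch: enqueue normally unless we are known to be in catch-up mode, in which case only record the newest skipped stream ordering:

    from typing import List, Optional

    class CatchupGateSketch:
        def __init__(self):
            self._pending_pdus = []  # type: List[int]
            self._catching_up = False
            self._last_successful_stream_ordering = None  # type: Optional[int]
            self._catchup_last_skipped = 0

        def send_pdu(self, stream_ordering):
            if not self._catching_up or self._last_successful_stream_ordering is None:
                # Not catching up (or we don't yet know): enqueue as usual.
                self._pending_pdus.append(stream_ordering)
            else:
                # Catching up: skip the queue and remember how far we got,
                # so the catch-up loop knows where to stop.
                self._catchup_last_skipped = stream_ordering
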
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b28ca61f80..82335e7a9d 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
import logging
from collections import namedtuple
-from typing import Dict, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -295,33 +295,37 @@ class TransactionStore(TransactionWorkerStore):
},
)
- async def bulk_store_destination_rooms_entries(
- self, room_and_destination_to_ordering: Dict[Tuple[str, str], int]
- ):
+ async def store_destination_rooms_entries(
+ self,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
"""
- Updates or creates `destination_rooms` entries for a number of events.
+        Updates or creates `destination_rooms` entries for a single event, batched across its destinations.
Args:
- room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id
+            destinations: the destinations the event should be sent to
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
"""
await self.db_pool.simple_upsert_many(
table="destinations",
key_names=("destination",),
- key_values={(d,) for _, d in room_and_destination_to_ordering.keys()},
+ key_values=[(d,) for d in destinations],
value_names=[],
value_values=[],
desc="store_destination_rooms_entries_dests",
)
+ rows = [(destination, room_id) for destination in destinations]
await self.db_pool.simple_upsert_many(
table="destination_rooms",
- key_names=("room_id", "destination"),
- key_values=list(room_and_destination_to_ordering.keys()),
+ key_names=("destination", "room_id"),
+ key_values=rows,
value_names=["stream_ordering"],
- value_values=[
- (stream_id,) for stream_id in room_and_destination_to_ordering.values()
- ],
+ value_values=[(stream_ordering,)] * len(rows),
desc="store_destination_rooms_entries_rooms",
)
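
The row construction for the second upsert can be sanity-checked in isolation. The identifiers below are made-up examples; key_values and value_values pair up positionally, so every (destination, room_id) key receives the same stream_ordering:

    destinations = ["one.example.com", "two.example.com"]
    room_id = "!room:example.com"
    stream_ordering = 12345

    rows = [(destination, room_id) for destination in destinations]
    assert rows == [
        ("one.example.com", "!room:example.com"),
        ("two.example.com", "!room:example.com"),
    ]
    # One identical value tuple per key row:
    assert [(stream_ordering,)] * len(rows) == [(12345,), (12345,)]
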