summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9702.misc1
-rw-r--r--contrib/experiments/test_messaging.py42
-rw-r--r--synapse/federation/sender/__init__.py140
-rw-r--r--synapse/federation/sender/per_destination_queue.py15
-rw-r--r--synapse/storage/databases/main/transactions.py28
5 files changed, 129 insertions, 97 deletions
diff --git a/changelog.d/9702.misc b/changelog.d/9702.misc
new file mode 100644
index 0000000000..c6e63450a9
--- /dev/null
+++ b/changelog.d/9702.misc
@@ -0,0 +1 @@
+Speed up federation transmission by using fewer database calls. Contributed by @ShadowJonathan.
diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py
index 31b8a68225..5dd172052b 100644
--- a/contrib/experiments/test_messaging.py
+++ b/contrib/experiments/test_messaging.py
@@ -224,14 +224,16 @@ class HomeServer(ReplicationHandler):
         destinations = yield self.get_servers_for_context(room_name)
 
         try:
-            yield self.replication_layer.send_pdu(
-                Pdu.create_new(
-                    context=room_name,
-                    pdu_type="sy.room.message",
-                    content={"sender": sender, "body": body},
-                    origin=self.server_name,
-                    destinations=destinations,
-                )
+            yield self.replication_layer.send_pdus(
+                [
+                    Pdu.create_new(
+                        context=room_name,
+                        pdu_type="sy.room.message",
+                        content={"sender": sender, "body": body},
+                        origin=self.server_name,
+                        destinations=destinations,
+                    )
+                ]
             )
         except Exception as e:
             logger.exception(e)
@@ -253,7 +255,7 @@ class HomeServer(ReplicationHandler):
                 origin=self.server_name,
                 destinations=destinations,
             )
-            yield self.replication_layer.send_pdu(pdu)
+            yield self.replication_layer.send_pdus([pdu])
         except Exception as e:
             logger.exception(e)
 
@@ -265,16 +267,18 @@ class HomeServer(ReplicationHandler):
         destinations = yield self.get_servers_for_context(room_name)
 
         try:
-            yield self.replication_layer.send_pdu(
-                Pdu.create_new(
-                    context=room_name,
-                    is_state=True,
-                    pdu_type="sy.room.member",
-                    state_key=invitee,
-                    content={"membership": "invite"},
-                    origin=self.server_name,
-                    destinations=destinations,
-                )
+            yield self.replication_layer.send_pdus(
+                [
+                    Pdu.create_new(
+                        context=room_name,
+                        is_state=True,
+                        pdu_type="sy.room.member",
+                        state_key=invitee,
+                        content={"membership": "invite"},
+                        origin=self.server_name,
+                        destinations=destinations,
+                    )
+                ]
             )
         except Exception as e:
             logger.exception(e)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 155161685d..952ad39f8c 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -18,8 +18,6 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set,
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
@@ -27,11 +25,7 @@ from synapse.federation.sender.per_destination_queue import PerDestinationQueue
 from synapse.federation.sender.transaction_manager import TransactionManager
 from synapse.federation.units import Edu
 from synapse.handlers.presence import get_interested_remotes
-from synapse.logging.context import (
-    make_deferred_yieldable,
-    preserve_fn,
-    run_in_background,
-)
+from synapse.logging.context import preserve_fn
 from synapse.metrics import (
     LaterGauge,
     event_processing_loop_counter,
@@ -39,7 +33,7 @@ from synapse.metrics import (
     events_processed_counter,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
+from synapse.types import Collection, JsonDict, ReadReceipt, RoomStreamToken
 from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
@@ -276,15 +270,27 @@ class FederationSender(AbstractFederationSender):
                 if not events and next_token >= self._last_poked_id:
                     break
 
-                async def handle_event(event: EventBase) -> None:
+                async def get_destinations_for_event(
+                    event: EventBase,
+                ) -> Collection[str]:
+                    """Computes the destinations to which this event must be sent.
+
+                    This returns an empty tuple when there are no destinations to send to,
+                    or if this event is not from this homeserver and it is not sending
+                    it on behalf of another server.
+
+                    Will also filter out destinations which this sender is not responsible for,
+                    if multiple federation senders exist.
+                    """
+
                     # Only send events for this server.
                     send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
                     is_mine = self.is_mine_id(event.sender)
                     if not is_mine and send_on_behalf_of is None:
-                        return
+                        return ()
 
                     if not event.internal_metadata.should_proactively_send():
-                        return
+                        return ()
 
                     destinations = None  # type: Optional[Set[str]]
                     if not event.prev_event_ids():
@@ -319,7 +325,7 @@ class FederationSender(AbstractFederationSender):
                                 "Failed to calculate hosts in room for event: %s",
                                 event.event_id,
                             )
-                            return
+                            return ()
 
                     destinations = {
                         d
@@ -329,17 +335,15 @@ class FederationSender(AbstractFederationSender):
                         )
                     }
 
+                    destinations.discard(self.server_name)
+
                     if send_on_behalf_of is not None:
                         # If we are sending the event on behalf of another server
                         # then it already has the event and there is no reason to
                         # send the event to it.
                         destinations.discard(send_on_behalf_of)
 
-                    logger.debug("Sending %s to %r", event, destinations)
-
                     if destinations:
-                        await self._send_pdu(event, destinations)
-
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
 
@@ -347,24 +351,29 @@ class FederationSender(AbstractFederationSender):
                             "federation_sender"
                         ).observe((now - ts) / 1000)
 
-                async def handle_room_events(events: Iterable[EventBase]) -> None:
-                    with Measure(self.clock, "handle_room_events"):
-                        for event in events:
-                            await handle_event(event)
-
-                events_by_room = {}  # type: Dict[str, List[EventBase]]
-                for event in events:
-                    events_by_room.setdefault(event.room_id, []).append(event)
-
-                await make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
-                            run_in_background(handle_room_events, evs)
-                            for evs in events_by_room.values()
-                        ],
-                        consumeErrors=True,
-                    )
-                )
+                        return destinations
+                    return ()
+
+                async def get_federatable_events_and_destinations(
+                    events: Iterable[EventBase],
+                ) -> List[Tuple[EventBase, Collection[str]]]:
+                    with Measure(self.clock, "get_destinations_for_events"):
+                        # Fetch federation destinations per event,
+                        # skip if get_destinations_for_event returns an empty collection,
+                        # return list of event->destinations pairs.
+                        return [
+                            (event, dests)
+                            for (event, dests) in [
+                                (event, await get_destinations_for_event(event))
+                                for event in events
+                            ]
+                            if dests
+                        ]
+
+                events_and_dests = await get_federatable_events_and_destinations(events)
+
+                # Send corresponding events to each destination queue
+                await self._distribute_events(events_and_dests)
 
                 await self.store.update_federation_out_pos("events", next_token)
 
@@ -382,7 +391,7 @@ class FederationSender(AbstractFederationSender):
                     events_processed_counter.inc(len(events))
 
                     event_processing_loop_room_count.labels("federation_sender").inc(
-                        len(events_by_room)
+                        len({event.room_id for event in events})
                     )
 
                 event_processing_loop_counter.labels("federation_sender").inc()
@@ -394,34 +403,53 @@ class FederationSender(AbstractFederationSender):
         finally:
             self._is_processing = False
 
-    async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
-        # We loop through all destinations to see whether we already have
-        # a transaction in progress. If we do, stick it in the pending_pdus
-        # table and we'll get back to it later.
+    async def _distribute_events(
+        self,
+        events_and_dests: Iterable[Tuple[EventBase, Collection[str]]],
+    ) -> None:
+        """Distribute events to the respective per_destination queues.
 
-        destinations = set(destinations)
-        destinations.discard(self.server_name)
-        logger.debug("Sending to: %s", str(destinations))
+        Also persists last-seen per-room stream_ordering to 'destination_rooms'.
 
-        if not destinations:
-            return
+        Args:
+            events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]).
+                              Every event is paired with its intended destinations (in federation).
+        """
+        # Tuples of room_id + destination to their max-seen stream_ordering
+        room_with_dest_stream_ordering = {}  # type: Dict[Tuple[str, str], int]
 
-        sent_pdus_destination_dist_total.inc(len(destinations))
-        sent_pdus_destination_dist_count.inc()
+        # List of events to send to each destination
+        events_by_dest = {}  # type: Dict[str, List[EventBase]]
 
-        assert pdu.internal_metadata.stream_ordering
+        # For each event-destinations pair...
+        for event, destinations in events_and_dests:
 
-        # track the fact that we have a PDU for these destinations,
-        # to allow us to perform catch-up later on if the remote is unreachable
-        # for a while.
-        await self.store.store_destination_rooms_entries(
-            destinations,
-            pdu.room_id,
-            pdu.internal_metadata.stream_ordering,
+            # (we got this from the database, it's filled)
+            assert event.internal_metadata.stream_ordering
+
+            sent_pdus_destination_dist_total.inc(len(destinations))
+            sent_pdus_destination_dist_count.inc()
+
+            # ...iterate over those destinations..
+            for destination in destinations:
+                # ...update their stream-ordering...
+                room_with_dest_stream_ordering[(event.room_id, destination)] = max(
+                    event.internal_metadata.stream_ordering,
+                    room_with_dest_stream_ordering.get((event.room_id, destination), 0),
+                )
+
+                # ...and add the event to each destination queue.
+                events_by_dest.setdefault(destination, []).append(event)
+
+        # Bulk-store destination_rooms stream_ids
+        await self.store.bulk_store_destination_rooms_entries(
+            room_with_dest_stream_ordering
         )
 
-        for destination in destinations:
-            self._get_per_destination_queue(destination).send_pdu(pdu)
+        for destination, pdus in events_by_dest.items():
+            logger.debug("Sending %d pdus to %s", len(pdus), destination)
+
+            self._get_per_destination_queue(destination).send_pdus(pdus)
 
     async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3b053ebcfb..3bb66bce32 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -154,19 +154,22 @@ class PerDestinationQueue:
             + len(self._pending_edus_keyed)
         )
 
-    def send_pdu(self, pdu: EventBase) -> None:
-        """Add a PDU to the queue, and start the transmission loop if necessary
+    def send_pdus(self, pdus: Iterable[EventBase]) -> None:
+        """Add PDUs to the queue, and start the transmission loop if necessary
 
         Args:
-            pdu: pdu to send
+            pdus: pdus to send
         """
         if not self._catching_up or self._last_successful_stream_ordering is None:
             # only enqueue the PDU if we are not catching up (False) or do not
             # yet know if we have anything to catch up (None)
-            self._pending_pdus.append(pdu)
+            self._pending_pdus.extend(pdus)
         else:
-            assert pdu.internal_metadata.stream_ordering
-            self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
+            self._catchup_last_skipped = max(
+                pdu.internal_metadata.stream_ordering
+                for pdu in pdus
+                if pdu.internal_metadata.stream_ordering is not None
+            )
 
         self.attempt_new_transaction()
 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 82335e7a9d..b28ca61f80 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -295,37 +295,33 @@ class TransactionStore(TransactionWorkerStore):
                 },
             )
 
-    async def store_destination_rooms_entries(
-        self,
-        destinations: Iterable[str],
-        room_id: str,
-        stream_ordering: int,
-    ) -> None:
+    async def bulk_store_destination_rooms_entries(
+        self, room_and_destination_to_ordering: Dict[Tuple[str, str], int]
+    ):
         """
-        Updates or creates `destination_rooms` entries in batch for a single event.
+        Updates or creates `destination_rooms` entries for a number of events.
 
         Args:
-            destinations: list of destinations
-            room_id: the room_id of the event
-            stream_ordering: the stream_ordering of the event
+            room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id
         """
 
         await self.db_pool.simple_upsert_many(
             table="destinations",
             key_names=("destination",),
-            key_values=[(d,) for d in destinations],
+            key_values={(d,) for _, d in room_and_destination_to_ordering.keys()},
             value_names=[],
             value_values=[],
             desc="store_destination_rooms_entries_dests",
         )
 
-        rows = [(destination, room_id) for destination in destinations]
         await self.db_pool.simple_upsert_many(
             table="destination_rooms",
-            key_names=("destination", "room_id"),
-            key_values=rows,
+            key_names=("room_id", "destination"),
+            key_values=list(room_and_destination_to_ordering.keys()),
             value_names=["stream_ordering"],
-            value_values=[(stream_ordering,)] * len(rows),
+            value_values=[
+                (stream_id,) for stream_id in room_and_destination_to_ordering.values()
+            ],
             desc="store_destination_rooms_entries_rooms",
         )