summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/federation/sender/__init__.py145
-rw-r--r--synapse/federation/sender/per_destination_queue.py15
-rw-r--r--synapse/storage/databases/main/transactions.py28
4 files changed, 75 insertions, 115 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index fbd49a93e1..5bbaa62de2 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.32.2"
+__version__ = "1.33.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 022bbf7dad..deb40f4610 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,26 +14,19 @@
 
 import abc
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    Hashable,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter
 
+from twisted.internet import defer
+
 import synapse.metrics
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.federation.sender.per_destination_queue import PerDestinationQueue
 from synapse.federation.sender.transaction_manager import TransactionManager
 from synapse.federation.units import Edu
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import (
     LaterGauge,
     event_processing_loop_counter,
@@ -262,27 +255,15 @@ class FederationSender(AbstractFederationSender):
                 if not events and next_token >= self._last_poked_id:
                     break
 
-                async def get_destinations_for_event(
-                    event: EventBase,
-                ) -> Collection[str]:
-                    """Computes the destinations to which this event must be sent.
-
-                    This returns an empty tuple when there are no destinations to send to,
-                    or if this event is not from this homeserver and it is not sending
-                    it on behalf of another server.
-
-                    Will also filter out destinations which this sender is not responsible for,
-                    if multiple federation senders exist.
-                    """
-
+                async def handle_event(event: EventBase) -> None:
                     # Only send events for this server.
                     send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
                     is_mine = self.is_mine_id(event.sender)
                     if not is_mine and send_on_behalf_of is None:
-                        return ()
+                        return
 
                     if not event.internal_metadata.should_proactively_send():
-                        return ()
+                        return
 
                     destinations = None  # type: Optional[Set[str]]
                     if not event.prev_event_ids():
@@ -317,7 +298,7 @@ class FederationSender(AbstractFederationSender):
                                 "Failed to calculate hosts in room for event: %s",
                                 event.event_id,
                             )
-                            return ()
+                            return
 
                     destinations = {
                         d
@@ -327,15 +308,17 @@ class FederationSender(AbstractFederationSender):
                         )
                     }
 
-                    destinations.discard(self.server_name)
-
                     if send_on_behalf_of is not None:
                         # If we are sending the event on behalf of another server
                         # then it already has the event and there is no reason to
                         # send the event to it.
                         destinations.discard(send_on_behalf_of)
 
+                    logger.debug("Sending %s to %r", event, destinations)
+
                     if destinations:
+                        await self._send_pdu(event, destinations)
+
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
 
@@ -343,29 +326,24 @@ class FederationSender(AbstractFederationSender):
                             "federation_sender"
                         ).observe((now - ts) / 1000)
 
-                        return destinations
-                    return ()
-
-                async def get_federatable_events_and_destinations(
-                    events: Iterable[EventBase],
-                ) -> List[Tuple[EventBase, Collection[str]]]:
-                    with Measure(self.clock, "get_destinations_for_events"):
-                        # Fetch federation destinations per event,
-                        # skip if get_destinations_for_event returns an empty collection,
-                        # return list of event->destinations pairs.
-                        return [
-                            (event, dests)
-                            for (event, dests) in [
-                                (event, await get_destinations_for_event(event))
-                                for event in events
-                            ]
-                            if dests
-                        ]
-
-                events_and_dests = await get_federatable_events_and_destinations(events)
-
-                # Send corresponding events to each destination queue
-                await self._distribute_events(events_and_dests)
+                async def handle_room_events(events: Iterable[EventBase]) -> None:
+                    with Measure(self.clock, "handle_room_events"):
+                        for event in events:
+                            await handle_event(event)
+
+                events_by_room = {}  # type: Dict[str, List[EventBase]]
+                for event in events:
+                    events_by_room.setdefault(event.room_id, []).append(event)
+
+                await make_deferred_yieldable(
+                    defer.gatherResults(
+                        [
+                            run_in_background(handle_room_events, evs)
+                            for evs in events_by_room.values()
+                        ],
+                        consumeErrors=True,
+                    )
+                )
 
                 await self.store.update_federation_out_pos("events", next_token)
 
@@ -383,7 +361,7 @@ class FederationSender(AbstractFederationSender):
                     events_processed_counter.inc(len(events))
 
                     event_processing_loop_room_count.labels("federation_sender").inc(
-                        len({event.room_id for event in events})
+                        len(events_by_room)
                     )
 
                 event_processing_loop_counter.labels("federation_sender").inc()
@@ -395,53 +373,34 @@ class FederationSender(AbstractFederationSender):
         finally:
             self._is_processing = False
 
-    async def _distribute_events(
-        self,
-        events_and_dests: Iterable[Tuple[EventBase, Collection[str]]],
-    ) -> None:
-        """Distribute events to the respective per_destination queues.
-
-        Also persists last-seen per-room stream_ordering to 'destination_rooms'.
-
-        Args:
-            events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]).
-                              Every event is paired with its intended destinations (in federation).
-        """
-        # Tuples of room_id + destination to their max-seen stream_ordering
-        room_with_dest_stream_ordering = {}  # type: Dict[Tuple[str, str], int]
-
-        # List of events to send to each destination
-        events_by_dest = {}  # type: Dict[str, List[EventBase]]
+    async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+        # We loop through all destinations to see whether we already have
+        # a transaction in progress. If we do, stick it in the pending_pdus
+        # table and we'll get back to it later.
 
-        # For each event-destinations pair...
-        for event, destinations in events_and_dests:
+        destinations = set(destinations)
+        destinations.discard(self.server_name)
+        logger.debug("Sending to: %s", str(destinations))
 
-            # (we got this from the database, it's filled)
-            assert event.internal_metadata.stream_ordering
-
-            sent_pdus_destination_dist_total.inc(len(destinations))
-            sent_pdus_destination_dist_count.inc()
+        if not destinations:
+            return
 
-            # ...iterate over those destinations..
-            for destination in destinations:
-                # ...update their stream-ordering...
-                room_with_dest_stream_ordering[(event.room_id, destination)] = max(
-                    event.internal_metadata.stream_ordering,
-                    room_with_dest_stream_ordering.get((event.room_id, destination), 0),
-                )
+        sent_pdus_destination_dist_total.inc(len(destinations))
+        sent_pdus_destination_dist_count.inc()
 
-                # ...and add the event to each destination queue.
-                events_by_dest.setdefault(destination, []).append(event)
+        assert pdu.internal_metadata.stream_ordering
 
-        # Bulk-store destination_rooms stream_ids
-        await self.store.bulk_store_destination_rooms_entries(
-            room_with_dest_stream_ordering
+        # track the fact that we have a PDU for these destinations,
+        # to allow us to perform catch-up later on if the remote is unreachable
+        # for a while.
+        await self.store.store_destination_rooms_entries(
+            destinations,
+            pdu.room_id,
+            pdu.internal_metadata.stream_ordering,
         )
 
-        for destination, pdus in events_by_dest.items():
-            logger.debug("Sending %d pdus to %s", len(pdus), destination)
-
-            self._get_per_destination_queue(destination).send_pdus(pdus)
+        for destination in destinations:
+            self._get_per_destination_queue(destination).send_pdu(pdu)
 
     async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """Send a RR to any other servers in the room
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3bb66bce32..3b053ebcfb 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -154,22 +154,19 @@ class PerDestinationQueue:
             + len(self._pending_edus_keyed)
         )
 
-    def send_pdus(self, pdus: Iterable[EventBase]) -> None:
-        """Add PDUs to the queue, and start the transmission loop if necessary
+    def send_pdu(self, pdu: EventBase) -> None:
+        """Add a PDU to the queue, and start the transmission loop if necessary
 
         Args:
-            pdus: pdus to send
+            pdu: pdu to send
         """
         if not self._catching_up or self._last_successful_stream_ordering is None:
             # only enqueue the PDU if we are not catching up (False) or do not
             # yet know if we have anything to catch up (None)
-            self._pending_pdus.extend(pdus)
+            self._pending_pdus.append(pdu)
         else:
-            self._catchup_last_skipped = max(
-                pdu.internal_metadata.stream_ordering
-                for pdu in pdus
-                if pdu.internal_metadata.stream_ordering is not None
-            )
+            assert pdu.internal_metadata.stream_ordering
+            self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
 
         self.attempt_new_transaction()
 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b28ca61f80..82335e7a9d 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -295,33 +295,37 @@ class TransactionStore(TransactionWorkerStore):
                 },
             )
 
-    async def bulk_store_destination_rooms_entries(
-        self, room_and_destination_to_ordering: Dict[Tuple[str, str], int]
-    ):
+    async def store_destination_rooms_entries(
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
+    ) -> None:
         """
-        Updates or creates `destination_rooms` entries for a number of events.
+        Updates or creates `destination_rooms` entries in batch for a single event.
 
         Args:
-            room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id
+            destinations: list of destinations
+            room_id: the room_id of the event
+            stream_ordering: the stream_ordering of the event
         """
 
         await self.db_pool.simple_upsert_many(
             table="destinations",
             key_names=("destination",),
-            key_values={(d,) for _, d in room_and_destination_to_ordering.keys()},
+            key_values=[(d,) for d in destinations],
             value_names=[],
             value_values=[],
             desc="store_destination_rooms_entries_dests",
         )
 
+        rows = [(destination, room_id) for destination in destinations]
         await self.db_pool.simple_upsert_many(
             table="destination_rooms",
-            key_names=("room_id", "destination"),
-            key_values=list(room_and_destination_to_ordering.keys()),
+            key_names=("destination", "room_id"),
+            key_values=rows,
             value_names=["stream_ordering"],
-            value_values=[
-                (stream_id,) for stream_id in room_and_destination_to_ordering.values()
-            ],
+            value_values=[(stream_ordering,)] * len(rows),
             desc="store_destination_rooms_entries_rooms",
         )