summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/sender/per_destination_queue.py287
-rw-r--r--synapse/federation/sender/transaction_manager.py48
2 files changed, 182 insertions, 153 deletions
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index deb519f3ef..cc0d765e5f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -17,6 +17,7 @@ import datetime
 import logging
 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
 
+import attr
 from prometheus_client import Counter
 
 from synapse.api.errors import (
@@ -93,6 +94,10 @@ class PerDestinationQueue:
         self._destination = destination
         self.transmission_loop_running = False
 
+        # Flag to signal to any running transmission loop that there is new data
+        # queued up to be sent.
+        self._new_data_to_send = False
+
         # True whilst we are sending events that the remote homeserver missed
         # because it was unreachable. We start in this state so we can perform
         # catch-up at startup.
@@ -108,7 +113,7 @@ class PerDestinationQueue:
         # destination (we are the only updater so this is safe)
         self._last_successful_stream_ordering = None  # type: Optional[int]
 
-        # a list of pending PDUs
+        # a queue of pending PDUs
         self._pending_pdus = []  # type: List[EventBase]
 
         # XXX this is never actually used: see
@@ -208,6 +213,10 @@ class PerDestinationQueue:
         transaction in the background.
         """
 
+        # Mark that we (may) have new things to send, so that any running
+        # transmission loop will recheck whether there is stuff to send.
+        self._new_data_to_send = True
+
         if self.transmission_loop_running:
             # XXX: this can get stuck on by a never-ending
             # request at which point pending_pdus just keeps growing.
@@ -250,125 +259,41 @@ class PerDestinationQueue:
 
             pending_pdus = []
             while True:
-                # We have to keep 2 free slots for presence and rr_edus
-                limit = MAX_EDUS_PER_TRANSACTION - 2
-
-                device_update_edus, dev_list_id = await self._get_device_update_edus(
-                    limit
-                )
-
-                limit -= len(device_update_edus)
-
-                (
-                    to_device_edus,
-                    device_stream_id,
-                ) = await self._get_to_device_message_edus(limit)
-
-                pending_edus = device_update_edus + to_device_edus
-
-                # BEGIN CRITICAL SECTION
-                #
-                # In order to avoid a race condition, we need to make sure that
-                # the following code (from popping the queues up to the point
-                # where we decide if we actually have any pending messages) is
-                # atomic - otherwise new PDUs or EDUs might arrive in the
-                # meantime, but not get sent because we hold the
-                # transmission_loop_running flag.
-
-                pending_pdus = self._pending_pdus
+                self._new_data_to_send = False
 
-                # We can only include at most 50 PDUs per transactions
-                pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
+                async with _TransactionQueueManager(self) as (
+                    pending_pdus,
+                    pending_edus,
+                ):
+                    if not pending_pdus and not pending_edus:
+                        logger.debug("TX [%s] Nothing to send", self._destination)
+
+                        # If we've gotten told about new things to send during
+                        # checking for things to send, we try looking again.
+                        # Otherwise new PDUs or EDUs might arrive in the meantime,
+                        # but not get sent because we hold the
+                        # `transmission_loop_running` flag.
+                        if self._new_data_to_send:
+                            continue
+                        else:
+                            return
 
-                pending_edus.extend(self._get_rr_edus(force_flush=False))
-                pending_presence = self._pending_presence
-                self._pending_presence = {}
-                if pending_presence:
-                    pending_edus.append(
-                        Edu(
-                            origin=self._server_name,
-                            destination=self._destination,
-                            edu_type="m.presence",
-                            content={
-                                "push": [
-                                    format_user_presence_state(
-                                        presence, self._clock.time_msec()
-                                    )
-                                    for presence in pending_presence.values()
-                                ]
-                            },
+                    if pending_pdus:
+                        logger.debug(
+                            "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                            self._destination,
+                            len(pending_pdus),
                         )
-                    )
 
-                pending_edus.extend(
-                    self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
-                )
-                while (
-                    len(pending_edus) < MAX_EDUS_PER_TRANSACTION
-                    and self._pending_edus_keyed
-                ):
-                    _, val = self._pending_edus_keyed.popitem()
-                    pending_edus.append(val)
-
-                if pending_pdus:
-                    logger.debug(
-                        "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        self._destination,
-                        len(pending_pdus),
+                    await self._transaction_manager.send_new_transaction(
+                        self._destination, pending_pdus, pending_edus
                     )
 
-                if not pending_pdus and not pending_edus:
-                    logger.debug("TX [%s] Nothing to send", self._destination)
-                    self._last_device_stream_id = device_stream_id
-                    return
-
-                # if we've decided to send a transaction anyway, and we have room, we
-                # may as well send any pending RRs
-                if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
-                    pending_edus.extend(self._get_rr_edus(force_flush=True))
-
-                # END CRITICAL SECTION
-
-                success = await self._transaction_manager.send_new_transaction(
-                    self._destination, pending_pdus, pending_edus
-                )
-                if success:
                     sent_transactions_counter.inc()
                     sent_edus_counter.inc(len(pending_edus))
                     for edu in pending_edus:
                         sent_edus_by_type.labels(edu.edu_type).inc()
-                    # Remove the acknowledged device messages from the database
-                    # Only bother if we actually sent some device messages
-                    if to_device_edus:
-                        await self._store.delete_device_msgs_for_remote(
-                            self._destination, device_stream_id
-                        )
 
-                    # also mark the device updates as sent
-                    if device_update_edus:
-                        logger.info(
-                            "Marking as sent %r %r", self._destination, dev_list_id
-                        )
-                        await self._store.mark_as_sent_devices_by_remote(
-                            self._destination, dev_list_id
-                        )
-
-                    self._last_device_stream_id = device_stream_id
-                    self._last_device_list_stream_id = dev_list_id
-
-                    if pending_pdus:
-                        # we sent some PDUs and it was successful, so update our
-                        # last_successful_stream_ordering in the destinations table.
-                        final_pdu = pending_pdus[-1]
-                        last_successful_stream_ordering = (
-                            final_pdu.internal_metadata.stream_ordering
-                        )
-                        assert last_successful_stream_ordering
-                        await self._store.set_destination_last_successful_stream_ordering(
-                            self._destination, last_successful_stream_ordering
-                        )
-                else:
-                    break
         except NotRetryingDestination as e:
             logger.debug(
                 "TX [%s] not ready for retry yet (next retry at %s) - "
@@ -401,7 +326,7 @@ class PerDestinationQueue:
                 self._pending_presence = {}
                 self._pending_rrs = {}
 
-            self._start_catching_up()
+                self._start_catching_up()
         except FederationDeniedError as e:
             logger.info(e)
         except HttpResponseException as e:
@@ -412,7 +337,6 @@ class PerDestinationQueue:
                 e,
             )
 
-            self._start_catching_up()
         except RequestSendFailed as e:
             logger.warning(
                 "TX [%s] Failed to send transaction: %s", self._destination, e
@@ -422,16 +346,12 @@ class PerDestinationQueue:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
-
-            self._start_catching_up()
         except Exception:
             logger.exception("TX [%s] Failed to send transaction", self._destination)
             for p in pending_pdus:
                 logger.info(
                     "Failed to send event %s to %s", p.event_id, self._destination
                 )
-
-            self._start_catching_up()
         finally:
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
@@ -499,13 +419,10 @@ class PerDestinationQueue:
                 rooms = [p.room_id for p in catchup_pdus]
                 logger.info("Catching up rooms to %s: %r", self._destination, rooms)
 
-            success = await self._transaction_manager.send_new_transaction(
+            await self._transaction_manager.send_new_transaction(
                 self._destination, catchup_pdus, []
             )
 
-            if not success:
-                return
-
             sent_transactions_counter.inc()
             final_pdu = catchup_pdus[-1]
             self._last_successful_stream_ordering = cast(
@@ -584,3 +501,135 @@ class PerDestinationQueue:
         """
         self._catching_up = True
         self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+    """A helper async context manager for pulling stuff off the queues and
+    tracking what was last successfully sent, etc.
+    """
+
+    queue = attr.ib(type=PerDestinationQueue)
+
+    _device_stream_id = attr.ib(type=Optional[int], default=None)
+    _device_list_id = attr.ib(type=Optional[int], default=None)
+    _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+    _pdus = attr.ib(type=List[EventBase], factory=list)
+
+    async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+        # First we calculate the EDUs we want to send, if any.
+
+        # We start by fetching device related EDUs, i.e device updates and to
+        # device messages. We have to keep 2 free slots for presence and rr_edus.
+        limit = MAX_EDUS_PER_TRANSACTION - 2
+
+        device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+            limit
+        )
+
+        if device_update_edus:
+            self._device_list_id = dev_list_id
+        else:
+            self.queue._last_device_list_stream_id = dev_list_id
+
+        limit -= len(device_update_edus)
+
+        (
+            to_device_edus,
+            device_stream_id,
+        ) = await self.queue._get_to_device_message_edus(limit)
+
+        if to_device_edus:
+            self._device_stream_id = device_stream_id
+        else:
+            self.queue._last_device_stream_id = device_stream_id
+
+        pending_edus = device_update_edus + to_device_edus
+
+        # Now add the read receipt EDU.
+        pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+        # And presence EDU.
+        if self.queue._pending_presence:
+            pending_edus.append(
+                Edu(
+                    origin=self.queue._server_name,
+                    destination=self.queue._destination,
+                    edu_type="m.presence",
+                    content={
+                        "push": [
+                            format_user_presence_state(
+                                presence, self.queue._clock.time_msec()
+                            )
+                            for presence in self.queue._pending_presence.values()
+                        ]
+                    },
+                )
+            )
+            self.queue._pending_presence = {}
+
+        # Finally add any other types of EDUs if there is room.
+        pending_edus.extend(
+            self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+        )
+        while (
+            len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+            and self.queue._pending_edus_keyed
+        ):
+            _, val = self.queue._pending_edus_keyed.popitem()
+            pending_edus.append(val)
+
+        # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+        # queue
+        self._pdus = self.queue._pending_pdus[:50]
+
+        if not self._pdus and not pending_edus:
+            return [], []
+
+        # if we've decided to send a transaction anyway, and we have room, we
+        # may as well send any pending RRs
+        if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+            pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+        if self._pdus:
+            self._last_stream_ordering = self._pdus[
+                -1
+            ].internal_metadata.stream_ordering
+            assert self._last_stream_ordering
+
+        return self._pdus, pending_edus
+
+    async def __aexit__(self, exc_type, exc, tb):
+        if exc_type is not None:
+            # Failed to send transaction, so we bail out.
+            return
+
+        # Successfully sent transactions, so we remove pending PDUs from the queue
+        if self._pdus:
+            self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+        # Succeeded to send the transaction so we record where we have sent up
+        # to in the various streams
+
+        if self._device_stream_id:
+            await self.queue._store.delete_device_msgs_for_remote(
+                self.queue._destination, self._device_stream_id
+            )
+            self.queue._last_device_stream_id = self._device_stream_id
+
+        # also mark the device updates as sent
+        if self._device_list_id:
+            logger.info(
+                "Marking as sent %r %r", self.queue._destination, self._device_list_id
+            )
+            await self.queue._store.mark_as_sent_devices_by_remote(
+                self.queue._destination, self._device_list_id
+            )
+            self.queue._last_device_list_stream_id = self._device_list_id
+
+        if self._last_stream_ordering:
+            # we sent some PDUs and it was successful, so update our
+            # last_successful_stream_ordering in the destinations table.
+            await self.queue._store.set_destination_last_successful_stream_ordering(
+                self.queue._destination, self._last_stream_ordering
+            )
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 2a9cd063c4..07b740c2f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -69,15 +69,12 @@ class TransactionManager:
         destination: str,
         pdus: List[EventBase],
         edus: List[Edu],
-    ) -> bool:
+    ) -> None:
         """
         Args:
             destination: The destination to send to (e.g. 'example.org')
             pdus: In-order list of PDUs to send
             edus: List of EDUs to send
-
-        Returns:
-            True iff the transaction was successful
         """
 
         # Make a transaction-sending opentracing span. This span follows on from
@@ -96,8 +93,6 @@ class TransactionManager:
                 edu.strip_context()
 
         with start_active_span_follows_from("send_transaction", span_contexts):
-            success = True
-
             logger.debug("TX [%s] _attempt_new_transaction", destination)
 
             txn_id = str(self._next_txn_id)
@@ -152,44 +147,29 @@ class TransactionManager:
                 response = await self._transport_layer.send_transaction(
                     transaction, json_data_cb
                 )
-                code = 200
             except HttpResponseException as e:
                 code = e.code
                 response = e.response
 
-                if e.code in (401, 404, 429) or 500 <= e.code:
-                    logger.info(
-                        "TX [%s] {%s} got %d response", destination, txn_id, code
-                    )
-                    raise e
-
-            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
-            if code == 200:
-                for e_id, r in response.get("pdus", {}).items():
-                    if "error" in r:
-                        logger.warning(
-                            "TX [%s] {%s} Remote returned error for %s: %s",
-                            destination,
-                            txn_id,
-                            e_id,
-                            r,
-                        )
-            else:
-                for p in pdus:
+                set_tag(tags.ERROR, True)
+
+                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+                raise
+
+            logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+            for e_id, r in response.get("pdus", {}).items():
+                if "error" in r:
                     logger.warning(
-                        "TX [%s] {%s} Failed to send event %s",
+                        "TX [%s] {%s} Remote returned error for %s: %s",
                         destination,
                         txn_id,
-                        p.event_id,
+                        e_id,
+                        r,
                     )
-                success = False
 
-            if success and pdus and destination in self._federation_metrics_domains:
+            if pdus and destination in self._federation_metrics_domains:
                 last_pdu = pdus[-1]
                 last_pdu_ts_metric.labels(server_name=destination).set(
                     last_pdu.origin_server_ts / 1000
                 )
-
-            set_tag(tags.ERROR, not success)
-            return success