summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/sender/__init__.py14
-rw-r--r--synapse/storage/databases/main/transactions.py99
-rw-r--r--synapse/util/__init__.py44
3 files changed, 126 insertions, 31 deletions
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 9ed6fc98b5..1888480881 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -192,10 +192,9 @@ sent_pdus_destination_dist_total = Counter(
 )
 
 # Time (in s) to wait before trying to wake up destinations that have
-# catch-up outstanding. This will also be the delay applied at startup
-# before trying the same.
+# catch-up outstanding.
 # Please note that rate limiting still applies, so while the loop is
-# executed every X seconds the destinations may not be wake up because
+# executed every X seconds the destinations may not be woken up because
 # they are being rate limited following previous attempt failures.
 WAKEUP_RETRY_PERIOD_SEC = 60
 
@@ -428,18 +427,17 @@ class FederationSender(AbstractFederationSender):
             / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
         )
 
+        self._external_cache = hs.get_external_cache()
+        self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock)
+
         # Regularly wake up destinations that have outstanding PDUs to be caught up
-        self.clock.looping_call(
+        self.clock.looping_call_now(
             run_as_background_process,
             WAKEUP_RETRY_PERIOD_SEC * 1000.0,
             "wake_destinations_needing_catchup",
             self._wake_destinations_needing_catchup,
         )
 
-        self._external_cache = hs.get_external_cache()
-
-        self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock)
-
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c91c44818f..08e0241f68 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -423,8 +423,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         self, after_destination: Optional[str]
     ) -> List[str]:
         """
-        Gets at most 25 destinations which have outstanding PDUs to be caught up,
-        and are not being backed off from
+        Get a list of destinations we should retry transaction sending to.
+
+        Returns up to 25 destinations which have outstanding PDUs or to-device messages,
+        and are not subject to a backoff.
+
         Args:
             after_destination:
                 If provided, all destinations must be lexicographically greater
@@ -448,30 +451,86 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
     def _get_catch_up_outstanding_destinations_txn(
         txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
     ) -> List[str]:
+        # We're looking for destinations which satisfy either of the following
+        # conditions:
+        #
+        #   * There is at least one room where we have an event that we have not yet
+        #     sent to them, indicated by a row in `destination_rooms` with a
+        #     `stream_ordering` older than the `last_successful_stream_ordering`
+        #     (if any) in `destinations`, or:
+        #
+        #   * There is at least one to-device message outstanding for the destination,
+        #     indicated by a row in `device_federation_outbox`.
+        #
+        # Of course, that may produce destinations where we are already busy sending
+        # the relevant PDU or to-device message, but in that case, waking up the
+        # sender will just be a no-op.
+        #
+        # From those two lists, we need to *exclude* destinations which are subject
+        # to a backoff (ie, where `destinations.retry_last_ts + destinations.retry_interval`
+        # is in the future). There is also an edge-case where, if the server was
+        # previously shut down in the middle of the first send attempt to a given
+        # destination, there may be no row in `destinations` at all; we need to include
+        # such rows in the output, which means we need to left-join rather than
+        # inner-join against `destinations`.
+        #
+        # The two sources of destinations (`destination_rooms` and
+        # `device_federation_outbox`) are queried separately and UNIONed; but the list
+        # may be very long, and we don't want to return all the rows at once. We
+        # therefore sort the output and just return the first 25 rows. Obviously that
+        # means there is no point in either of the inner queries returning more than
+        # 25 results, since any further results are certain to be dropped by the outer
+        # LIMIT. In order to help the query-optimiser understand that, we *also* sort
+        # and limit the *inner* queries, hence we express them as CTEs rather than
+        # sub-queries.
+        #
+        # (NB: we make sure to do the top-level sort and limit on the database, rather
+        # than making two queries and combining the result in Python. We could otherwise
+        # suffer from slight differences in sort order between Python and the database,
+        # which would make the `after_destination` condition unreliable.)
+
         q = """
-            SELECT DISTINCT destination FROM destinations
-            INNER JOIN destination_rooms USING (destination)
-                WHERE
-                    stream_ordering > last_successful_stream_ordering
-                    AND destination > ?
-                    AND (
-                        retry_last_ts IS NULL OR
-                        retry_last_ts + retry_interval < ?
-                    )
-                    ORDER BY destination
-                    LIMIT 25
+        WITH pdu_destinations AS (
+            SELECT DISTINCT destination FROM destination_rooms
+            LEFT JOIN destinations USING (destination)
+            WHERE
+                destination > ?
+                AND destination_rooms.stream_ordering > COALESCE(destinations.last_successful_stream_ordering, 0)
+                AND (
+                    destinations.retry_last_ts IS NULL OR
+                    destinations.retry_last_ts + destinations.retry_interval < ?
+                )
+            ORDER BY destination
+            LIMIT 25
+        ), to_device_destinations AS (
+            SELECT DISTINCT destination FROM device_federation_outbox
+            LEFT JOIN destinations USING (destination)
+            WHERE
+               destination > ?
+               AND (
+                    destinations.retry_last_ts IS NULL OR
+                    destinations.retry_last_ts + destinations.retry_interval < ?
+               )
+            ORDER BY destination
+            LIMIT 25
+        )
+
+        SELECT destination FROM pdu_destinations
+        UNION SELECT destination FROM to_device_destinations
+            ORDER BY destination
+            LIMIT 25
         """
+
+        # everything is lexicographically greater than "" so this gives
+        # us the first batch of up to 25.
+        after_destination = after_destination or ""
+
         txn.execute(
             q,
-            (
-                # everything is lexicographically greater than "" so this gives
-                # us the first batch of up to 25.
-                after_destination or "",
-                now_time_ms,
-            ),
+            (after_destination, now_time_ms, after_destination, now_time_ms),
         )
-
         destinations = [row[0] for row in txn]
+
         return destinations
 
     async def get_destinations_paginate(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9e374354ec..e0d876e84b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -117,7 +117,11 @@ class Clock:
         return int(self.time() * 1000)
 
     def looping_call(
-        self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
+        self,
+        f: Callable[P, object],
+        msec: float,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> LoopingCall:
         """Call a function repeatedly.
 
@@ -134,12 +138,46 @@ class Clock:
         Args:
             f: The function to call repeatedly.
             msec: How long to wait between calls in milliseconds.
-            *args: Postional arguments to pass to function.
+            *args: Positional arguments to pass to function.
             **kwargs: Key arguments to pass to function.
         """
+        return self._looping_call_common(f, msec, False, *args, **kwargs)
+
+    def looping_call_now(
+        self,
+        f: Callable[P, object],
+        msec: float,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> LoopingCall:
+        """Call a function immediately, and then repeatedly thereafter.
+
+        As with `looping_call`: subsequent calls are not scheduled until after the
+        the Awaitable returned by a previous call has finished.
+
+        Also as with `looping_call`: the function is called with no logcontext and
+        you probably want to wrap it in `run_as_background_process`.
+
+        Args:
+            f: The function to call repeatedly.
+            msec: How long to wait between calls in milliseconds.
+            *args: Positional arguments to pass to function.
+            **kwargs: Key arguments to pass to function.
+        """
+        return self._looping_call_common(f, msec, True, *args, **kwargs)
+
+    def _looping_call_common(
+        self,
+        f: Callable[P, object],
+        msec: float,
+        now: bool,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> LoopingCall:
+        """Common functionality for `looping_call` and `looping_call_now`"""
         call = task.LoopingCall(f, *args, **kwargs)
         call.clock = self._reactor
-        d = call.start(msec / 1000.0, now=False)
+        d = call.start(msec / 1000.0, now=now)
         d.addErrback(log_failure, "Looping call died", consumeErrors=False)
         return call