diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 9ed6fc98b5..1888480881 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -192,10 +192,9 @@ sent_pdus_destination_dist_total = Counter(
)
# Time (in s) to wait before trying to wake up destinations that have
-# catch-up outstanding. This will also be the delay applied at startup
-# before trying the same.
+# catch-up outstanding.
# Please note that rate limiting still applies, so while the loop is
-# executed every X seconds the destinations may not be wake up because
+# executed every X seconds, the destinations may not be woken up because
# they are being rate limited following previous attempt failures.
WAKEUP_RETRY_PERIOD_SEC = 60
@@ -428,18 +427,19 @@ class FederationSender(AbstractFederationSender):
/ hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
)
+ self._external_cache = hs.get_external_cache()
+ self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock)
+
# Regularly wake up destinations that have outstanding PDUs to be caught up
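+        # Note that `looping_call_now` fires the first attempt immediately at
+        # startup, rather than waiting WAKEUP_RETRY_PERIOD_SEC for the first run.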
- self.clock.looping_call(
+ self.clock.looping_call_now(
run_as_background_process,
WAKEUP_RETRY_PERIOD_SEC * 1000.0,
"wake_destinations_needing_catchup",
self._wake_destinations_needing_catchup,
)
- self._external_cache = hs.get_external_cache()
-
- self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock)
-
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c91c44818f..08e0241f68 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -423,8 +423,15 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
self, after_destination: Optional[str]
) -> List[str]:
"""
- Gets at most 25 destinations which have outstanding PDUs to be caught up,
- and are not being backed off from
+        Get a list of destinations to which we should retry sending transactions.
+
+ Returns up to 25 destinations which have outstanding PDUs or to-device messages,
+ and are not subject to a backoff.
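+
+        Results are returned in lexicographic order, so a caller can page through
+        the full set by passing the last destination it received as
+        `after_destination` on the next call.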
+
Args:
after_destination:
If provided, all destinations must be lexicographically greater
@@ -448,30 +455,88 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
def _get_catch_up_outstanding_destinations_txn(
txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
) -> List[str]:
+ # We're looking for destinations which satisfy either of the following
+ # conditions:
+ #
+ # * There is at least one room where we have an event that we have not yet
+ # sent to them, indicated by a row in `destination_rooms` with a
+ # `stream_ordering` older than the `last_successful_stream_ordering`
+ # (if any) in `destinations`, or:
+ #
+ # * There is at least one to-device message outstanding for the destination,
+ # indicated by a row in `device_federation_outbox`.
+ #
+ # Of course, that may produce destinations where we are already busy sending
+ # the relevant PDU or to-device message, but in that case, waking up the
+ # sender will just be a no-op.
+ #
+ # From those two lists, we need to *exclude* destinations which are subject
+ # to a backoff (ie, where `destinations.retry_last_ts + destinations.retry_interval`
+ # is in the future). There is also an edge-case where, if the server was
+ # previously shut down in the middle of the first send attempt to a given
+ # destination, there may be no row in `destinations` at all; we need to include
+ # such rows in the output, which means we need to left-join rather than
+ # inner-join against `destinations`.
+ #
+ # The two sources of destinations (`destination_rooms` and
+ # `device_federation_outbox`) are queried separately and UNIONed; but the list
+ # may be very long, and we don't want to return all the rows at once. We
+ # therefore sort the output and just return the first 25 rows. Obviously that
+ # means there is no point in either of the inner queries returning more than
+ # 25 results, since any further results are certain to be dropped by the outer
+ # LIMIT. In order to help the query-optimiser understand that, we *also* sort
+ # and limit the *inner* queries, hence we express them as CTEs rather than
+ # sub-queries.
+ #
+ # (NB: we make sure to do the top-level sort and limit on the database, rather
+            # than making two queries and combining the results in Python. We could otherwise
+ # suffer from slight differences in sort order between Python and the database,
+ # which would make the `after_destination` condition unreliable.)
+
q = """
- SELECT DISTINCT destination FROM destinations
- INNER JOIN destination_rooms USING (destination)
- WHERE
- stream_ordering > last_successful_stream_ordering
- AND destination > ?
- AND (
- retry_last_ts IS NULL OR
- retry_last_ts + retry_interval < ?
- )
- ORDER BY destination
- LIMIT 25
+ WITH pdu_destinations AS (
+ SELECT DISTINCT destination FROM destination_rooms
+ LEFT JOIN destinations USING (destination)
+ WHERE
+ destination > ?
+ AND destination_rooms.stream_ordering > COALESCE(destinations.last_successful_stream_ordering, 0)
+ AND (
+ destinations.retry_last_ts IS NULL OR
+ destinations.retry_last_ts + destinations.retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ ), to_device_destinations AS (
+ SELECT DISTINCT destination FROM device_federation_outbox
+ LEFT JOIN destinations USING (destination)
+ WHERE
+ destination > ?
+ AND (
+ destinations.retry_last_ts IS NULL OR
+ destinations.retry_last_ts + destinations.retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ )
+
+ SELECT destination FROM pdu_destinations
+ UNION SELECT destination FROM to_device_destinations
+ ORDER BY destination
+ LIMIT 25
"""
+
+            # Everything is lexicographically greater than "", so this gives us
+            # the first batch of up to 25.
+ after_destination = after_destination or ""
+
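+            # Both CTEs take the same (after_destination, now_time_ms) pair of
+            # parameters, hence it appears twice in the argument tuple below.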
txn.execute(
q,
- (
- # everything is lexicographically greater than "" so this gives
- # us the first batch of up to 25.
- after_destination or "",
- now_time_ms,
- ),
+ (after_destination, now_time_ms, after_destination, now_time_ms),
)
-
destinations = [row[0] for row in txn]
+
return destinations
async def get_destinations_paginate(
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9e374354ec..e0d876e84b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -117,7 +117,11 @@ class Clock:
return int(self.time() * 1000)
def looping_call(
- self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
+ self,
+ f: Callable[P, object],
+ msec: float,
+ *args: P.args,
+ **kwargs: P.kwargs,
) -> LoopingCall:
"""Call a function repeatedly.
@@ -134,12 +138,48 @@ class Clock:
Args:
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
- *args: Postional arguments to pass to function.
+ *args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
+ return self._looping_call_common(f, msec, False, *args, **kwargs)
+
+ def looping_call_now(
+ self,
+ f: Callable[P, object],
+ msec: float,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> LoopingCall:
+ """Call a function immediately, and then repeatedly thereafter.
+
+ As with `looping_call`: subsequent calls are not scheduled until after the
+        Awaitable returned by a previous call has finished.
+
+ Also as with `looping_call`: the function is called with no logcontext and
+ you probably want to wrap it in `run_as_background_process`.
+
+ Args:
+ f: The function to call repeatedly.
+ msec: How long to wait between calls in milliseconds.
+ *args: Positional arguments to pass to function.
+            **kwargs: Keyword arguments to pass to function.
+ """
+ return self._looping_call_common(f, msec, True, *args, **kwargs)
+
+ def _looping_call_common(
+ self,
+ f: Callable[P, object],
+ msec: float,
+ now: bool,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> LoopingCall:
+ """Common functionality for `looping_call` and `looping_call_now`"""
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
- d = call.start(msec / 1000.0, now=False)
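+        # `now` controls whether `f` fires immediately on start, or only after
+        # the first `msec` interval has elapsed.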
+ d = call.start(msec / 1000.0, now=now)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
|