diff --git a/changelog.d/12500.misc b/changelog.d/12500.misc
new file mode 100644
index 0000000000..dbe3f7f5d1
--- /dev/null
+++ b/changelog.d/12500.misc
@@ -0,0 +1 @@
+Immediately retry any requests that have backed off when a server comes back online.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index e686445955..c2ec3caa0e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -73,7 +73,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -353,6 +353,13 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
+ self._sleeper = AwakenableSleeper(self.reactor)
+
+ def wake_destination(self, destination: str) -> None:
+ """Called when the remote server may have come back online."""
+
+ self._sleeper.wake(destination)
+
async def _send_request_with_optional_trailing_slash(
self,
request: MatrixFederationRequest,
@@ -474,6 +481,8 @@ class MatrixFederationHttpClient:
self._store,
backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
+ notifier=self.hs.get_notifier(),
+ replication_client=self.hs.get_replication_command_handler(),
)
method_bytes = request.method.encode("ascii")
@@ -664,7 +673,9 @@ class MatrixFederationHttpClient:
delay,
)
- await self.clock.sleep(delay)
+ # Sleep for the calculated delay, or wake up immediately
+ # if we get notified that the server is back up.
+ await self._sleeper.sleep(request.destination, delay * 1000)
retries_left -= 1
else:
raise
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 16d15a1f33..01a50b9d62 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,9 +228,7 @@ class Notifier:
# Called when there are new things to stream over replication
self.replication_callbacks: List[Callable[[], None]] = []
- # Called when remote servers have come back online after having been
- # down.
- self.remote_server_up_callbacks: List[Callable[[str], None]] = []
+ self._federation_client = hs.get_federation_http_client()
self._third_party_rules = hs.get_third_party_event_rules()
@@ -731,3 +729,7 @@ class Notifier:
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
+
+ # Tell the federation client about the fact the server is back up, so
+ # that any in flight requests can be immediately retried.
+ self._federation_client.wake_destination(server)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index b91020117f..7f1d41eb3c 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -778,3 +778,60 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
deferred.chainDeferred(new_deferred)
return new_deferred
+
+
+class AwakenableSleeper:
+ """Allows explicitly waking up deferreds related to an entity that are
+ currently sleeping.
+ """
+
+ def __init__(self, reactor: IReactorTime) -> None:
+ self._streams: Dict[str, Set[defer.Deferred[None]]] = {}
+ self._reactor = reactor
+
+ def wake(self, name: str) -> None:
+ """Wake everything related to `name` that is currently sleeping."""
+ stream_set = self._streams.pop(name, set())
+ for deferred in stream_set:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ pass
+
+ async def sleep(self, name: str, delay_ms: int) -> None:
+ """Sleep for the given number of milliseconds, or return if the given
+ `name` is explicitly woken up.
+ """
+
+ # Create a deferred that gets called in N seconds
+ sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+ call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None)
+
+ # Create a deferred that will get called if `wake` is called with
+ # the same `name`.
+ stream_set = self._streams.setdefault(name, set())
+ notify_deferred: "defer.Deferred[None]" = defer.Deferred()
+ stream_set.add(notify_deferred)
+
+ try:
+ # Wait for either the delay or for `wake` to be called.
+ await make_deferred_yieldable(
+ defer.DeferredList(
+ [sleep_deferred, notify_deferred],
+ fireOnOneCallback=True,
+ fireOnOneErrback=True,
+ consumeErrors=True,
+ )
+ )
+ finally:
+ # Clean up the state
+ curr_stream_set = self._streams.get(name)
+ if curr_stream_set is not None:
+ curr_stream_set.discard(notify_deferred)
+ if len(curr_stream_set) == 0:
+ self._streams.pop(name)
+
+ # Cancel the sleep if we were woken up
+ if call.active():
+ call.cancel()
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d81f2527d7..81bfed268e 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -14,13 +14,17 @@
import logging
import random
from types import TracebackType
-from typing import Any, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional, Type
import synapse.logging.context
from synapse.api.errors import CodeMessageException
from synapse.storage import DataStore
from synapse.util import Clock
+if TYPE_CHECKING:
+ from synapse.notifier import Notifier
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
+
logger = logging.getLogger(__name__)
# the initial backoff, after the first transaction fails
@@ -131,6 +135,8 @@ class RetryDestinationLimiter:
retry_interval: int,
backoff_on_404: bool = False,
backoff_on_failure: bool = True,
+ notifier: Optional["Notifier"] = None,
+ replication_client: Optional["ReplicationCommandHandler"] = None,
):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@@ -160,6 +166,9 @@ class RetryDestinationLimiter:
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
+ self.notifier = notifier
+ self.replication_client = replication_client
+
def __enter__(self) -> None:
pass
@@ -239,6 +248,19 @@ class RetryDestinationLimiter:
retry_last_ts,
self.retry_interval,
)
+
+ if self.notifier:
+ # Inform the relevant places that the remote server is back up.
+ self.notifier.notify_remote_server_up(self.destination)
+
+ if self.replication_client:
+ # If we're on a worker we try and inform master about this. The
+ # replication client doesn't hook into the notifier to avoid
+ # infinite loops where we send a `REMOTE_SERVER_UP` command to
+ # master, which then echoes it back to us which in turn pokes
+ # the notifier.
+ self.replication_client.send_remote_server_up(self.destination)
+
except Exception:
logger.exception("Failed to store destination_retry_timings")
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index daacc54c72..9d5010bf92 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -28,6 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.async_helpers import (
+ AwakenableSleeper,
ObservableDeferred,
concurrently_execute,
delay_cancellation,
@@ -35,6 +36,7 @@ from synapse.util.async_helpers import (
timeout_deferred,
)
+from tests.server import get_clock
from tests.unittest import TestCase
@@ -496,3 +498,81 @@ class DelayCancellationTests(TestCase):
# logging context.
blocking_d.callback(None)
self.successResultOf(d)
+
+
+class AwakenableSleeperTests(TestCase):
+ "Tests AwakenableSleeper"
+
+ def test_sleep(self):
+ reactor, _ = get_clock()
+ sleeper = AwakenableSleeper(reactor)
+
+ d = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ reactor.pump([0.0])
+ self.assertFalse(d.called)
+
+ reactor.advance(0.5)
+ self.assertFalse(d.called)
+
+ reactor.advance(0.6)
+ self.assertTrue(d.called)
+
+ def test_explicit_wake(self):
+ reactor, _ = get_clock()
+ sleeper = AwakenableSleeper(reactor)
+
+ d = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ reactor.pump([0.0])
+ self.assertFalse(d.called)
+
+ reactor.advance(0.5)
+ self.assertFalse(d.called)
+
+ sleeper.wake("name")
+ self.assertTrue(d.called)
+
+ reactor.advance(0.6)
+
+ def test_multiple_sleepers_timeout(self):
+ reactor, _ = get_clock()
+ sleeper = AwakenableSleeper(reactor)
+
+ d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ reactor.advance(0.6)
+ self.assertFalse(d1.called)
+
+ # Add another sleeper
+ d2 = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ # Only the first sleep should time out now.
+ reactor.advance(0.6)
+ self.assertTrue(d1.called)
+ self.assertFalse(d2.called)
+
+ reactor.advance(0.6)
+ self.assertTrue(d2.called)
+
+ def test_multiple_sleepers_wake(self):
+ reactor, _ = get_clock()
+ sleeper = AwakenableSleeper(reactor)
+
+ d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ reactor.advance(0.5)
+ self.assertFalse(d1.called)
+
+ # Add another sleeper
+ d2 = defer.ensureDeferred(sleeper.sleep("name", 1000))
+
+ # Neither should fire yet
+ reactor.advance(0.3)
+ self.assertFalse(d1.called)
+ self.assertFalse(d2.called)
+
+ # Explicitly waking both up works
+ sleeper.wake("name")
+ self.assertTrue(d1.called)
+ self.assertTrue(d2.called)
|