summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/matrixfederationclient.py15
-rw-r--r--synapse/notifier.py8
-rw-r--r--synapse/util/async_helpers.py57
-rw-r--r--synapse/util/retryutils.py24
4 files changed, 98 insertions, 6 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index e686445955..c2ec3caa0e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -73,7 +73,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.types import JsonDict
 from synapse.util import json_decoder
-from synapse.util.async_helpers import timeout_deferred
+from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -353,6 +353,13 @@ class MatrixFederationHttpClient:
 
         self._cooperator = Cooperator(scheduler=schedule)
 
+        self._sleeper = AwakenableSleeper(self.reactor)
+
+    def wake_destination(self, destination: str) -> None:
+        """Called when the remote server may have come back online."""
+
+        self._sleeper.wake(destination)
+
     async def _send_request_with_optional_trailing_slash(
         self,
         request: MatrixFederationRequest,
@@ -474,6 +481,8 @@ class MatrixFederationHttpClient:
             self._store,
             backoff_on_404=backoff_on_404,
             ignore_backoff=ignore_backoff,
+            notifier=self.hs.get_notifier(),
+            replication_client=self.hs.get_replication_command_handler(),
         )
 
         method_bytes = request.method.encode("ascii")
@@ -664,7 +673,9 @@ class MatrixFederationHttpClient:
                             delay,
                         )
 
-                        await self.clock.sleep(delay)
+                        # Sleep for the calculated delay, or wake up immediately
+                        # if we get notified that the server is back up.
+                        await self._sleeper.sleep(request.destination, delay * 1000)
                         retries_left -= 1
                     else:
                         raise
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 16d15a1f33..01a50b9d62 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,9 +228,7 @@ class Notifier:
         # Called when there are new things to stream over replication
         self.replication_callbacks: List[Callable[[], None]] = []
 
-        # Called when remote servers have come back online after having been
-        # down.
-        self.remote_server_up_callbacks: List[Callable[[str], None]] = []
+        self._federation_client = hs.get_federation_http_client()
 
         self._third_party_rules = hs.get_third_party_event_rules()
 
@@ -731,3 +729,7 @@ class Notifier:
         # circular dependencies.
         if self.federation_sender:
             self.federation_sender.wake_destination(server)
+
+        # Tell the federation client about the fact the server is back up, so
+        # that any in flight requests can be immediately retried.
+        self._federation_client.wake_destination(server)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index b91020117f..7f1d41eb3c 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -778,3 +778,60 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
     new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
     deferred.chainDeferred(new_deferred)
     return new_deferred
+
+
+class AwakenableSleeper:
+    """Allows explicitly waking up deferreds related to an entity that are
+    currently sleeping.
+    """
+
+    def __init__(self, reactor: IReactorTime) -> None:
+        self._streams: Dict[str, Set[defer.Deferred[None]]] = {}
+        self._reactor = reactor
+
+    def wake(self, name: str) -> None:
+        """Wake everything related to `name` that is currently sleeping."""
+        stream_set = self._streams.pop(name, set())
+        for deferred in stream_set:
+            try:
+                with PreserveLoggingContext():
+                    deferred.callback(None)
+            except Exception:
+                pass
+
+    async def sleep(self, name: str, delay_ms: int) -> None:
+        """Sleep for the given number of milliseconds, or return if the given
+        `name` is explicitly woken up.
+        """
+
+        # Create a deferred that gets called in N seconds
+        sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+        call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None)
+
+        # Create a deferred that will get called if `wake` is called with
+        # the same `name`.
+        stream_set = self._streams.setdefault(name, set())
+        notify_deferred: "defer.Deferred[None]" = defer.Deferred()
+        stream_set.add(notify_deferred)
+
+        try:
+            # Wait for either the delay or for `wake` to be called.
+            await make_deferred_yieldable(
+                defer.DeferredList(
+                    [sleep_deferred, notify_deferred],
+                    fireOnOneCallback=True,
+                    fireOnOneErrback=True,
+                    consumeErrors=True,
+                )
+            )
+        finally:
+            # Clean up the state
+            curr_stream_set = self._streams.get(name)
+            if curr_stream_set is not None:
+                curr_stream_set.discard(notify_deferred)
+                if len(curr_stream_set) == 0:
+                    self._streams.pop(name)
+
+            # Cancel the sleep if we were woken up
+            if call.active():
+                call.cancel()
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d81f2527d7..81bfed268e 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -14,13 +14,17 @@
 import logging
 import random
 from types import TracebackType
-from typing import Any, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional, Type
 
 import synapse.logging.context
 from synapse.api.errors import CodeMessageException
 from synapse.storage import DataStore
 from synapse.util import Clock
 
+if TYPE_CHECKING:
+    from synapse.notifier import Notifier
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+
 logger = logging.getLogger(__name__)
 
 # the initial backoff, after the first transaction fails
@@ -131,6 +135,8 @@ class RetryDestinationLimiter:
         retry_interval: int,
         backoff_on_404: bool = False,
         backoff_on_failure: bool = True,
+        notifier: Optional["Notifier"] = None,
+        replication_client: Optional["ReplicationCommandHandler"] = None,
     ):
         """Marks the destination as "down" if an exception is thrown in the
         context, except for CodeMessageException with code < 500.
@@ -160,6 +166,9 @@ class RetryDestinationLimiter:
         self.backoff_on_404 = backoff_on_404
         self.backoff_on_failure = backoff_on_failure
 
+        self.notifier = notifier
+        self.replication_client = replication_client
+
     def __enter__(self) -> None:
         pass
 
@@ -239,6 +248,19 @@ class RetryDestinationLimiter:
                     retry_last_ts,
                     self.retry_interval,
                 )
+
+                if self.notifier:
+                    # Inform the relevant places that the remote server is back up.
+                    self.notifier.notify_remote_server_up(self.destination)
+
+                if self.replication_client:
+                    # If we're on a worker we try and inform master about this. The
+                    # replication client doesn't hook into the notifier to avoid
+                    # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                    # master, which then echoes it back to us which in turn pokes
+                    # the notifier.
+                    self.replication_client.send_remote_server_up(self.destination)
+
             except Exception:
                 logger.exception("Failed to store destination_retry_timings")