summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/federation_sender.py12
-rw-r--r--synapse/federation/sender/__init__.py18
-rw-r--r--synapse/federation/transport/server.py19
-rw-r--r--synapse/notifier.py31
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/replication/tcp/commands.py17
-rw-r--r--synapse/replication/tcp/protocol.py15
-rw-r--r--synapse/replication/tcp/resource.py9
-rw-r--r--synapse/server.pyi12
9 files changed, 129 insertions, 7 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index a57cf991ac..38d11fdd0f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -158,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
         args.update(self.send_handler.stream_positions())
         return args
 
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
+        # Let's wake up the transaction queue for the server in case we have
+        # pending stuff to send to it.
+        self.send_handler.wake_destination(server)
+
 
 def start(config_options):
     try:
@@ -205,7 +212,7 @@ class FederationSenderHandler(object):
     to the federation sender.
     """
 
-    def __init__(self, hs, replication_client):
+    def __init__(self, hs: FederationSenderServer, replication_client):
         self.store = hs.get_datastore()
         self._is_mine_id = hs.is_mine_id
         self.federation_sender = hs.get_federation_sender()
@@ -226,6 +233,9 @@ class FederationSenderHandler(object):
             self.store.get_room_max_stream_ordering()
         )
 
+    def wake_destination(self, server: str):
+        self.federation_sender.wake_destination(server)
+
     def stream_positions(self):
         return {"federation": self.federation_position}
 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4ebb0e8bc0..36c83c3027 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -21,6 +21,7 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
+import synapse
 import synapse.metrics
 from synapse.federation.sender.per_destination_queue import PerDestinationQueue
 from synapse.federation.sender.transaction_manager import TransactionManager
@@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter(
 
 
 class FederationSender(object):
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.hs = hs
         self.server_name = hs.hostname
 
@@ -482,7 +483,20 @@ class FederationSender(object):
 
     def send_device_messages(self, destination):
         if destination == self.server_name:
-            logger.info("Not sending device update to ourselves")
+            logger.warning("Not sending device update to ourselves")
+            return
+
+        self._get_per_destination_queue(destination).attempt_new_transaction()
+
+    def wake_destination(self, destination: str):
+        """Called when we want to retry sending transactions to a remote.
+
+        This is mainly useful if the remote server has been down and we think it
+        might have come back.
+        """
+
+        if destination == self.server_name:
+            logger.warning("Not waking up ourselves")
             return
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b4cbf23394..d8cf9ed299 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -44,6 +44,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
@@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError):
 
 
 class Authenticator(object):
-    def __init__(self, hs):
+    def __init__(self, hs: HomeServer):
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.notifer = hs.get_notifier()
+
+        self.replication_client = None
+        if hs.config.worker.worker_app:
+            self.replication_client = hs.get_tcp_replication()
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     async def authenticate_request(self, request, content):
@@ -166,6 +172,17 @@ class Authenticator(object):
         try:
             logger.info("Marking origin %r as up", origin)
             await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+            # Inform the relevant places that the remote server is back up.
+            self.notifer.notify_remote_server_up(origin)
+            if self.replication_client:
+                # If we're on a worker we try and inform master about this. The
+                # replication client doesn't hook into the notifier to avoid
+                # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                # master, which then echoes it back to us which in turn pokes
+                # the notifier.
+                self.replication_client.send_remote_server_up(origin)
+
         except Exception:
             logger.exception("Error resetting retry timings on %s", origin)
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 5f5f765bea..6132727cbd 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,11 +15,13 @@
 
 import logging
 from collections import namedtuple
+from typing import Callable, List
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
+import synapse.server
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
 from synapse.handlers.presence import format_user_presence_state
@@ -154,7 +156,7 @@ class Notifier(object):
 
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.user_to_user_stream = {}
         self.room_to_user_streams = {}
 
@@ -164,7 +166,12 @@ class Notifier(object):
         self.store = hs.get_datastore()
         self.pending_new_room_events = []
 
-        self.replication_callbacks = []
+        # Called when there are new things to stream over replication
+        self.replication_callbacks = []  # type: List[Callable[[], None]]
+
+        # Called when remote servers have come back online after having been
+        # down.
+        self.remote_server_up_callbacks = []  # type: List[Callable[[str], None]]
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
@@ -205,7 +212,7 @@ class Notifier(object):
             "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
         )
 
-    def add_replication_callback(self, cb):
+    def add_replication_callback(self, cb: Callable[[], None]):
         """Add a callback that will be called when some new data is available.
         Callback is not given any arguments. It should *not* return a Deferred - if
         it needs to do any asynchronous work, a background thread should be started and
@@ -213,6 +220,12 @@ class Notifier(object):
         """
         self.replication_callbacks.append(cb)
 
+    def add_remote_server_up_callback(self, cb: Callable[[str], None]):
+        """Add a callback that will be called when synapse detects a server
+        has been
+        """
+        self.remote_server_up_callbacks.append(cb)
+
     def on_new_room_event(
         self, event, room_stream_id, max_room_stream_id, extra_users=[]
     ):
@@ -522,3 +535,15 @@ class Notifier(object):
         """Notify the any replication listeners that there's a new event"""
         for cb in self.replication_callbacks:
             cb()
+
+    def notify_remote_server_up(self, server: str):
+        """Notify any replication that a remote server has come back up
+        """
+        # We call federation_sender directly rather than registering as a
+        # callback as a) we already have a reference to it and b) it introduces
+        # circular dependencies.
+        if self.federation_sender:
+            self.federation_sender.wake_destination(server)
+
+        for cb in self.remote_server_up_callbacks:
+            cb(server)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 52a0aefe68..fc06a7b053 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -143,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         if d:
             d.callback(data)
 
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
     def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
         subscribe to streams.
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index cbb36b9acf..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -387,6 +387,20 @@ class UserIpCommand(Command):
         )
 
 
+class RemoteServerUpCommand(Command):
+    """Sent when a worker has detected that a remote server is no longer
+    "down" and retry timings should be reset.
+
+    If sent from a client the server will relay to all other workers.
+
+    Format::
+
+        REMOTE_SERVER_UP <server>
+    """
+
+    NAME = "REMOTE_SERVER_UP"
+
+
 _COMMANDS = (
     ServerCommand,
     RdataCommand,
@@ -401,6 +415,7 @@ _COMMANDS = (
     RemovePusherCommand,
     InvalidateCacheCommand,
     UserIpCommand,
+    RemoteServerUpCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     SyncCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
 
 # The commands the client is allowed to send
@@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
     InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5f4bdf84d2..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -76,6 +76,7 @@ from synapse.replication.tcp.commands import (
     PingCommand,
     PositionCommand,
     RdataCommand,
+    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
     SyncCommand,
@@ -460,6 +461,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_INVALIDATE_CACHE(self, cmd):
         self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.streamer.on_remote_server_up(cmd.data)
+
     async def on_USER_IP(self, cmd):
         self.streamer.on_user_ip(
             cmd.user_id,
@@ -555,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
         self.streamer.lost_connection(self)
@@ -589,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    async def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_streams_to_replicate(self):
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -707,6 +719,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
 
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.handler.on_remote_server_up(cmd.data)
+
     def replicate(self, stream_name, token):
         """Send the subscription request to the server
         """
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b1752e88cd..6ebf944f66 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -120,6 +120,7 @@ class ReplicationStreamer(object):
             self.federation_sender = hs.get_federation_sender()
 
         self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -288,6 +289,14 @@ class ReplicationStreamer(object):
         )
         await self._server_notices_sender.on_user_ip(user_id)
 
+    @measure_func("repl.on_remote_server_up")
+    def on_remote_server_up(self, server: str):
+        self.notifier.notify_remote_server_up(server)
+
+    def send_remote_server_up(self, server: str):
+        for conn in self.connections:
+            conn.send_remote_server_up(server)
+
     def send_sync_to_all_connections(self, data):
         """Sends a SYNC command to all clients.
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index b5e0b57095..0731403047 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,5 @@
+import twisted.internet
+
 import synapse.api.auth
 import synapse.config.homeserver
 import synapse.federation.sender
@@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account
 import synapse.handlers.device
 import synapse.handlers.e2e_keys
 import synapse.handlers.message
+import synapse.handlers.presence
 import synapse.handlers.room
 import synapse.handlers.room_member
 import synapse.handlers.set_password
 import synapse.http.client
+import synapse.notifier
 import synapse.rest.media.v1.media_repository
 import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
@@ -85,3 +89,11 @@ class HomeServer(object):
         self,
     ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
         pass
+    def get_notifier(self) -> synapse.notifier.Notifier:
+        pass
+    def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
+        pass
+    def get_clock(self) -> synapse.util.Clock:
+        pass
+    def get_reactor(self) -> twisted.internet.base.ReactorBase:
+        pass