summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/tcp/handler.py63
-rw-r--r--synapse/replication/tcp/protocol.py2
-rw-r--r--synapse/replication/tcp/redis.py2
3 files changed, 51 insertions, 16 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 3a8c7c7e2d..b8f49a8d0f 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -117,7 +117,6 @@ class ReplicationCommandHandler:
         self._server_notices_sender = None
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
-            self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote server
@@ -163,7 +162,7 @@ class ReplicationCommandHandler:
             port = hs.config.worker_replication_port
             hs.get_reactor().connectTCP(host, port, self._factory)
 
-    async def on_REPLICATE(self, cmd: ReplicateCommand):
+    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         # We only want to announce positions by the writer of the streams.
         # Currently this is just the master process.
         if not self._is_master:
@@ -173,7 +172,7 @@ class ReplicationCommandHandler:
             current_token = stream.current_token()
             self.send_command(PositionCommand(stream_name, current_token))
 
-    async def on_USER_SYNC(self, cmd: UserSyncCommand):
+    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
         user_sync_counter.inc()
 
         if self._is_master:
@@ -181,17 +180,23 @@ class ReplicationCommandHandler:
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
 
-    async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+    async def on_CLEAR_USER_SYNC(
+        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+    ):
         if self._is_master:
             await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
 
-    async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+    async def on_FEDERATION_ACK(
+        self, conn: AbstractConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.token)
 
-    async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+    async def on_REMOVE_PUSHER(
+        self, conn: AbstractConnection, cmd: RemovePusherCommand
+    ):
         remove_pusher_counter.inc()
 
         if self._is_master:
@@ -201,7 +206,9 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+    async def on_INVALIDATE_CACHE(
+        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+    ):
         invalidate_cache_counter.inc()
 
         if self._is_master:
@@ -211,7 +218,7 @@ class ReplicationCommandHandler:
                 cmd.cache_func, tuple(cmd.keys)
             )
 
-    async def on_USER_IP(self, cmd: UserIpCommand):
+    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
         if self._is_master:
@@ -227,7 +234,7 @@ class ReplicationCommandHandler:
         if self._server_notices_sender:
             await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, cmd: RdataCommand):
+    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
@@ -278,7 +285,7 @@ class ReplicationCommandHandler:
         logger.debug("Received rdata %s -> %s", stream_name, token)
         await self._replication_data_handler.on_rdata(stream_name, token, rows)
 
-    async def on_POSITION(self, cmd: PositionCommand):
+    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         stream = self._streams.get(cmd.stream_name)
         if not stream:
             logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -332,12 +339,30 @@ class ReplicationCommandHandler:
 
             self._streams_connected.add(cmd.stream_name)
 
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+    async def on_REMOTE_SERVER_UP(
+        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
-        if self._is_master:
-            self._notifier.notify_remote_server_up(cmd.data)
+        self._notifier.notify_remote_server_up(cmd.data)
+
+        # We relay to all other connections to ensure every instance gets the
+        # notification.
+        #
+        # When configured to use redis we'll always only have one connection and
+        # so this is a no-op (all instances will have already received the same
+        # REMOTE_SERVER_UP command).
+        #
+        # For direct TCP connections this will relay to all other connections
+        # connected to us. When on master this will correctly fan out to all
+        # other direct TCP clients and on workers there'll only be the one
+        # connection to master.
+        #
+        # (The logic here should also be sound if we have a mix of Redis and
+        # direct TCP connections so long as there is only one traffic route
+        # between two instances, but that is not currently supported).
+        self.send_command(cmd, ignore_conn=conn)
 
     def new_connection(self, connection: AbstractConnection):
         """Called when we have a new connection.
@@ -382,11 +407,21 @@ class ReplicationCommandHandler:
         """
         return bool(self._connections)
 
-    def send_command(self, cmd: Command):
+    def send_command(
+        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+    ):
         """Send a command to all connected connections.
+
+        Args:
+            cmd
+            ignore_conn: If set don't send command to the given connection.
+                Used when relaying commands from one connection to all others.
         """
         if self._connections:
             for connection in self._connections:
+                if connection == ignore_conn:
+                    continue
+
                 try:
                     connection.send_command(cmd)
                 except Exception:
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e3f64eba8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            await cmd_func(self, cmd)
             handled = True
 
         if not handled:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 49b3ed0c5e..617e860f95 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            await cmd_func(self, cmd)
             handled = True
 
         if not handled: