diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/handler.py | 63 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 2 | ||||
-rw-r--r-- | synapse/replication/tcp/redis.py | 2 |
3 files changed, 51 insertions, 16 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3a8c7c7e2d..b8f49a8d0f 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -117,7 +117,6 @@ class ReplicationCommandHandler: self._server_notices_sender = None if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -163,7 +162,7 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) - async def on_REPLICATE(self, cmd: ReplicateCommand): + async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. # Currently this is just the master process. if not self._is_master: @@ -173,7 +172,7 @@ class ReplicationCommandHandler: current_token = stream.current_token() self.send_command(PositionCommand(stream_name, current_token)) - async def on_USER_SYNC(self, cmd: UserSyncCommand): + async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() if self._is_master: @@ -181,17 +180,23 @@ class ReplicationCommandHandler: cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + async def on_CLEAR_USER_SYNC( + self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + ): if self._is_master: await self._presence_handler.update_external_syncs_clear(cmd.instance_id) - async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + async def on_FEDERATION_ACK( + self, conn: AbstractConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.token) - async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + async def on_REMOVE_PUSHER( + self, conn: AbstractConnection, cmd: RemovePusherCommand + ): remove_pusher_counter.inc() if self._is_master: @@ -201,7 +206,9 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + async def on_INVALIDATE_CACHE( + self, conn: AbstractConnection, cmd: InvalidateCacheCommand + ): invalidate_cache_counter.inc() if self._is_master: @@ -211,7 +218,7 @@ class ReplicationCommandHandler: cmd.cache_func, tuple(cmd.keys) ) - async def on_USER_IP(self, cmd: UserIpCommand): + async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() if self._is_master: @@ -227,7 +234,7 @@ class ReplicationCommandHandler: if self._server_notices_sender: await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, cmd: RdataCommand): + async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -278,7 +285,7 @@ class ReplicationCommandHandler: logger.debug("Received rdata %s -> %s", stream_name, token) await self._replication_data_handler.on_rdata(stream_name, token, rows) - async def on_POSITION(self, cmd: PositionCommand): + async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -332,12 +339,30 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + async def on_REMOTE_SERVER_UP( + self, conn: AbstractConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) - if self._is_master: - self._notifier.notify_remote_server_up(cmd.data) + self._notifier.notify_remote_server_up(cmd.data) + + # We relay to all other connections to ensure every instance gets the + # notification. + # + # When configured to use redis we'll always only have one connection and + # so this is a no-op (all instances will have already received the same + # REMOTE_SERVER_UP command). + # + # For direct TCP connections this will relay to all other connections + # connected to us. When on master this will correctly fan out to all + # other direct TCP clients and on workers there'll only be the one + # connection to master. + # + # (The logic here should also be sound if we have a mix of Redis and + # direct TCP connections so long as there is only one traffic route + # between two instances, but that is not currently supported). + self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. @@ -382,11 +407,21 @@ class ReplicationCommandHandler: """ return bool(self._connections) - def send_command(self, cmd: Command): + def send_command( + self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + ): """Send a command to all connected connections. + + Args: + cmd + ignore_conn: If set don't send command to the given connection. + Used when relaying commands from one connection to all others. """ if self._connections: for connection in self._connections: + if connection == ignore_conn: + continue + try: connection.send_command(cmd) except Exception: diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e3f64eba8f..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 49b3ed0c5e..617e860f95 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: |