diff options
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r-- | synapse/replication/tcp/protocol.py | 286 |
1 files changed, 36 insertions, 250 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index ff720beb56..d4456f42f3 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct @@ -64,26 +63,22 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( COMMAND_MAP, - VALID_CLIENT_COMMANDS, - VALID_SERVER_COMMANDS, Command, ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.replication.tcp.streams import STREAMS_MAP, Stream -from synapse.server import HomeServer -from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + import synapse.server + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -124,16 +119,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): delimiter = b"\n" - # Valid commands we expect to receive - VALID_INBOUND_COMMANDS = [] # type: Collection[str] - - # Valid commands we can send - VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] - max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock, handler): self.clock = clock + self.handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -173,6 +163,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -210,11 +202,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): line = line.decode("utf-8") cmd_name, rest_of_line = line.split(" ", 1) - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) - return - self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( @@ -246,8 +233,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -255,6 +257,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.loseConnection() self.on_connection_closed() + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def send_error(self, error_string, *args): """Send an error to remote and close the connection. """ @@ -376,6 +381,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -399,162 +406,35 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): - VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS - VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__(self, hs, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. - - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - - -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): - VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS - VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__( self, - hs: HomeServer, + hs: "synapse.server.HomeServer", client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + handler, ): - BaseReplicationStreamProtocol.__init__(self, clock) + BaseReplicationStreamProtocol.__init__(self, clock, handler) self.instance_id = hs.get_instance_id() self.client_name = client_name self.server_name = server_name - self.handler = handler self.streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() @@ -570,106 +450,16 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): - self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) - # Once we've connected subscribe to the necessary streams + self.send_command(NameCommand(self.client_name)) self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) - - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) - - # We're now up to date wit the stream - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) - - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) - def replicate(self): """Send the subscription request to the server """ @@ -677,10 +467,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) - # The following simply registers metrics for the replication connections |