summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-06-10 16:58:10 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-06-10 16:58:10 +0100
commite8a7a853f8e73b6ea27a1e02bd9786114bfcec3b (patch)
tree711dab9b030223706473bf0b8279b0eb86bca317 /synapse/replication/tcp/protocol.py
parentMerge pull request #39 from matrix-org/dinsic-release-v1.12.x (diff)
parentUpdate changelog based on feedback. (diff)
downloadsynapse-e8a7a853f8e73b6ea27a1e02bd9786114bfcec3b.tar.xz
Merge branch 'release-v1.13.0' of github.com:matrix-org/synapse into dinsic-release-v1.14.x
* 'release-v1.13.0' of github.com:matrix-org/synapse: (257 commits)
  Update changelog based on feedback.
  Move warnings in the changelog and re-iterate changes to branches.
  1.13.0
  update dh-virtualenv (#7526)
  1.13.0rc3
  Hash passwords earlier in the registration process (#7523)
  1.13.0rc2
  1.13.0rc2
  Stop `get_joined_users` corruption from custom statuses (#7376)
  Do not validate that the client dict is stable during UI Auth. (#7483)
  Fix new flake8 errors (#7489)
  Don't UPGRADE database rows
  RST indenting
  Put rollback instructions in upgrade notes
  Fix changelog typo
  Oh yeah, RST
  Absolute URL it is then
  Fix upgrade notes link
  Provide summary of upgrade issues in changelog. Fix )
  Move next version notes from changelog to upgrade notes
  ...
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py459
1 files changed, 88 insertions, 371 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py

index d185cc0c8f..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -52,45 +50,51 @@ import abc import fcntl import logging import struct -from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple - -from six import iteritems, iterkeys +from typing import TYPE_CHECKING, List from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, Command, ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, - RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, + parse_command_from_line, ) -from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) +tcp_inbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_inbound_commands", + "Number of commands received from replication, by command and name of process connected to", + ["command", "name"], +) + +tcp_outbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_outbound_commands", + "Number of commands sent to replication, by command and name of process connected to", + ["command", "name"], +) + # A list of all connected protocols. This allows us to send metrics about the # connections. connected_connections = [] @@ -119,7 +123,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): are only sent by the server. On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed - command. + command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -135,8 +139,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.clock = clock + self.command_handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -155,9 +160,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] - self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] - def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -176,6 +178,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.command_handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -203,38 +207,30 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line): + def lineReceived(self, line: bytes): """Called when we've received a line """ if line.strip() == "": # Ignore blank lines return - line = line.decode("utf-8") - cmd_name, rest_of_line = line.split(" ", 1) + linestr = line.decode("utf-8") - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) return - self.last_received_command = self.clock.time_msec() + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) + return - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 - ) + self.last_received_command = self.clock.time_msec() - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return + tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() # Now lets try and call on_<CMD_NAME> function run_as_background_process( @@ -244,13 +240,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND>, which should return an awaitable. + First calls `self.on_<COMMAND>` if it exists, then calls + `self.command_handler.on_<COMMAND>` if it exists. This allows for + protocol level handling of commands (e.g. PINGs), before delegating to + the handler. Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(self, cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -282,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._queue_command(cmd) return - self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1 - ) + tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc() + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -379,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.command_handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -405,232 +420,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + super().__init__(clock, handler) self.server_name = server_name - self.streamer = streamer - - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] def connectionMade(self): self.send_command(ServerCommand(self.server_name)) - BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) + super().connectionMade() async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - await self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. - - We need to check if the client is interested in the stream or not - """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - - -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS @@ -638,110 +442,51 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler: "ReplicationCommandHandler", ): - BaseReplicationStreamProtocol.__init__(self, clock) + super().__init__(clock, command_handler) self.client_name = client_name self.server_name = server_name - self.handler = handler - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set() # type: Set[str] - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} # type: Dict[str, Any] def connectionMade(self): self.send_command(NameCommand(self.client_name)) - BaseReplicationStreamProtocol.connectionMade(self) + super().connectionMade() # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) - - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() + self.replicate() async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + def replicate(self): + """Send the subscription request to the server + """ + logger.info("[%s] Subscribing to replication streams", self.id()) - await self.handler.on_position(cmd.stream_name, cmd.token) + self.send_command(ReplicateCommand()) - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ - def replicate(self, stream_name, token): - """Send the subscription request to the server + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + pass - self.send_command(ReplicateCommand(stream_name, token)) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. +AbstractConnection.register(BaseReplicationStreamProtocol) # The following simply registers metrics for the replication connections @@ -804,31 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge( for p in connected_connections }, ) - - -tcp_inbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_inbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.inbound_commands_counter) - }, -) - -tcp_outbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_outbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.outbound_commands_counter) - }, -) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -)