summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r-- synapse/replication/tcp/protocol.py 286
1 files changed, 36 insertions, 250 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index ff720beb56..d4456f42f3 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-import abc
 import fcntl
 import logging
 import struct
@@ -64,26 +63,22 @@ from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
     COMMAND_MAP,
-    VALID_CLIENT_COMMANDS,
-    VALID_SERVER_COMMANDS,
     Command,
     ErrorCommand,
     NameCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
     RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
 )
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
-from synapse.server import HomeServer
-from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+MYPY = False
+if MYPY:
+    import synapse.server
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -124,16 +119,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     delimiter = b"\n"
 
-    # Valid commands we expect to receive
-    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
-
-    # Valid commands we can send
-    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
-
     max_line_buffer = 10000
 
-    def __init__(self, clock):
+    def __init__(self, clock, handler):
         self.clock = clock
+        self.handler = handler
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
@@ -173,6 +163,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
 
+        self.handler.new_connection(self)
+
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
@@ -210,11 +202,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         line = line.decode("utf-8")
         cmd_name, rest_of_line = line.split(" ", 1)
 
-        if cmd_name not in self.VALID_INBOUND_COMMANDS:
-            logger.error("[%s] invalid command %s", self.id(), cmd_name)
-            self.send_error("invalid command: %s", cmd_name)
-            return
-
         self.last_received_command = self.clock.time_msec()
 
         self.inbound_commands_counter[cmd_name] = (
@@ -246,8 +233,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         Args:
             cmd: received command
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        await handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -255,6 +257,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.transport.loseConnection()
         self.on_connection_closed()
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def send_error(self, error_string, *args):
         """Send an error to remote and close the connection.
         """
@@ -376,6 +381,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        self.handler.lost_connection(self)
+
         if self.transport:
             self.transport.unregisterProducer()
 
@@ -399,162 +406,35 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
 
 class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
-    VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
-    VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
-
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(self, hs, server_name, clock, handler):
+        BaseReplicationStreamProtocol.__init__(self, clock, handler)  # Old style class
 
         self.server_name = server_name
-        self.streamer = streamer
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
 
     async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    async def on_USER_SYNC(self, cmd):
-        await self.streamer.on_user_sync(
-            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    async def on_CLEAR_USER_SYNC(self, cmd):
-        await self.streamer.on_clear_user_syncs(cmd.instance_id)
-
-    async def on_REPLICATE(self, cmd):
-        # Subscribe to all streams we're publishing to.
-        for stream_name in self.streamer.streams_by_name:
-            current_token = self.streamer.get_stream_token(stream_name)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-    async def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
-
-    async def on_REMOVE_PUSHER(self, cmd):
-        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    async def on_INVALIDATE_CACHE(self, cmd):
-        await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.streamer.on_remote_server_up(cmd.data)
-
-    async def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        self.send_command(RdataCommand(stream_name, token, data))
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
-    """
-    The interface for the handler that should be passed to
-    ClientReplicationStreamProtocol
-    """
-
-    @abc.abstractmethod
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def on_sync(self, data):
-        """Called when get a new SYNC command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        raise NotImplementedError()
-
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
-    VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
-    VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
-
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "synapse.server.HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
-        handler: AbstractReplicationClientHandler,
+        handler,
     ):
-        BaseReplicationStreamProtocol.__init__(self, clock)
+        BaseReplicationStreamProtocol.__init__(self, clock, handler)
 
         self.instance_id = hs.get_instance_id()
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = handler
 
         self.streams = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
@@ -570,106 +450,16 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.pending_batches = {}  # type: Dict[str, List[Any]]
 
     def connectionMade(self):
-        self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)
 
-        # Once we've connected subscribe to the necessary streams
+        self.send_command(NameCommand(self.client_name))
         self.replicate()
 
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    async def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None or stream_name in self.streams_connecting:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.handler.on_rdata(stream_name, cmd.token, rows)
-
-    async def on_POSITION(self, cmd: PositionCommand):
-        stream = self.streams.get(cmd.stream_name)
-        if not stream:
-            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
-            return
-
-        # Find where we previously streamed up to.
-        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
-        if current_token is None:
-            logger.warning(
-                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
-            )
-            return
-
-        # Fetch all updates between then and now.
-        limited = True
-        while limited:
-            updates, current_token, limited = await stream.get_updates_since(
-                current_token, cmd.token
-            )
-
-            # Check if the connection was closed underneath us, if so we bail
-            # rather than risk having concurrent catch ups going on.
-            if self.state == ConnectionStates.CLOSED:
-                return
-
-            if updates:
-                await self.handler.on_rdata(
-                    cmd.stream_name,
-                    current_token,
-                    [stream.parse_row(update[1]) for update in updates],
-                )
-
-        # We've now caught up to position sent to us, notify handler.
-        await self.handler.on_position(cmd.stream_name, cmd.token)
-
-        # We're now up to date wit the stream
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
-        # Check if the connection was closed underneath us, if so we bail
-        # rather than risk having concurrent catch ups going on.
-        if self.state == ConnectionStates.CLOSED:
-            return
-
-        # Handle any RDATA that came in while we were catching up.
-        rows = self.pending_batches.pop(cmd.stream_name, [])
-        if rows:
-            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
-
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
-
     def replicate(self):
         """Send the subscription request to the server
         """
@@ -677,10 +467,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         self.send_command(ReplicateCommand())
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
-
 
 # The following simply registers metrics for the replication connections