summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py197
1 files changed, 31 insertions, 166 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dae246825f..f2a37f568e 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-import abc
 import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set
+from typing import TYPE_CHECKING, DefaultDict, List
 
 from six import iteritems
 
@@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
     from synapse.server import HomeServer
 
 
@@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.streamer.lost_connection(self)
 
 
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
-    """
-    The interface for the handler that should be passed to
-    ClientReplicationStreamProtocol
-    """
-
-    @abc.abstractmethod
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def on_sync(self, data):
-        """Called when get a new SYNC command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        raise NotImplementedError()
-
-
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
@@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         client_name: str,
         server_name: str,
         clock: Clock,
-        handler: AbstractReplicationClientHandler,
+        command_handler: "ReplicationCommandHandler",
     ):
         BaseReplicationStreamProtocol.__init__(self, clock)
 
@@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = handler
-
-        self.streams = {
-            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
-        }  # type: Dict[str, Stream]
-
-        # Set of stream names that have been subscribe to, but haven't yet
-        # caught up with. This is used to track when the client has been fully
-        # connected to the remote.
-        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]
-
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
-        self.pending_batches = {}  # type: Dict[str, List[Any]]
+        self.handler = command_handler
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
@@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
+        self.handler.finished_connecting()
 
-    async def on_SERVER(self, cmd):
-        if cmd.data != self.server_name:
-            logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
-            self.send_error("Wrong remote")
-
-    async def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None or stream_name in self.streams_connecting:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.handler.on_rdata(stream_name, cmd.token, rows)
-
-    async def on_POSITION(self, cmd: PositionCommand):
-        stream = self.streams.get(cmd.stream_name)
-        if not stream:
-            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
-            return
-
-        # Find where we previously streamed up to.
-        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
-        if current_token is None:
-            logger.warning(
-                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
-            )
-            return
-
-        # Fetch all updates between then and now.
-        limited = True
-        while limited:
-            updates, current_token, limited = await stream.get_updates_since(
-                current_token, cmd.token
-            )
-
-            # Check if the connection was closed underneath us, if so we bail
-            # rather than risk having concurrent catch ups going on.
-            if self.state == ConnectionStates.CLOSED:
-                return
-
-            if updates:
-                await self.handler.on_rdata(
-                    cmd.stream_name,
-                    current_token,
-                    [stream.parse_row(update[1]) for update in updates],
-                )
+    async def handle_command(self, cmd: Command):
+        """Handle a command we have received over the replication stream.
 
-        # We've now caught up to position sent to us, notify handler.
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        Delegates to `command_handler.on_<COMMAND>`, which must return an
+        awaitable.
 
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+        Args:
+            cmd: received command
+        """
+        handled = False
 
-        # Check if the connection was closed underneath us, if so we bail
-        # rather than risk having concurrent catch ups going on.
-        if self.state == ConnectionStates.CLOSED:
-            return
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
 
-        # Handle any RDATA that came in while we were catching up.
-        rows = self.pending_batches.pop(cmd.stream_name, [])
-        if rows:
-            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
 
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
+    async def on_SERVER(self, cmd):
+        if cmd.data != self.server_name:
+            logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
+            self.send_error("Wrong remote")
 
     def replicate(self):
         """Send the subscription request to the server
@@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge(
         for k, count in iteritems(p.outbound_commands_counter)
     },
 )
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
-    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)