diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dae246825f..f2a37f568e 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set
+from typing import TYPE_CHECKING, DefaultDict, List
from six import iteritems
@@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.server import HomeServer
@@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
- """
- The interface for the handler that should be passed to
- ClientReplicationStreamProtocol
- """
-
- @abc.abstractmethod
- async def on_rdata(self, stream_name, token, rows):
- """Called to handle a batch of replication data with a given stream token.
-
- Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_position(self, stream_name, token):
- """Called when we get new position data."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def on_sync(self, data):
- """Called when get a new SYNC command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_streams_to_replicate(self):
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- raise NotImplementedError()
-
-
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
@@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
client_name: str,
server_name: str,
clock: Clock,
- handler: AbstractReplicationClientHandler,
+ command_handler: "ReplicationCommandHandler",
):
BaseReplicationStreamProtocol.__init__(self, clock)
@@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.client_name = client_name
self.server_name = server_name
- self.handler = handler
-
- self.streams = {
- stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
- } # type: Dict[str, Stream]
-
- # Set of stream names that have been subscribe to, but haven't yet
- # caught up with. This is used to track when the client has been fully
- # connected to the remote.
- self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
-
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
- self.pending_batches = {} # type: Dict[str, List[Any]]
+ self.handler = command_handler
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
+ self.handler.finished_connecting()
- async def on_SERVER(self, cmd):
- if cmd.data != self.server_name:
- logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
- self.send_error("Wrong remote")
-
- async def on_RDATA(self, cmd):
- stream_name = cmd.stream_name
- inbound_rdata_count.labels(stream_name).inc()
-
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception(
- "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
- )
- raise
-
- if cmd.token is None or stream_name in self.streams_connecting:
- # I.e. this is part of a batch of updates for this stream. Batch
- # until we get an update for the stream with a non None token
- self.pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(stream_name, [])
- rows.append(row)
- await self.handler.on_rdata(stream_name, cmd.token, rows)
-
- async def on_POSITION(self, cmd: PositionCommand):
- stream = self.streams.get(cmd.stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
- return
-
- # Find where we previously streamed up to.
- current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
- if current_token is None:
- logger.warning(
- "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
- )
- return
-
- # Fetch all updates between then and now.
- limited = True
- while limited:
- updates, current_token, limited = await stream.get_updates_since(
- current_token, cmd.token
- )
-
- # Check if the connection was closed underneath us, if so we bail
- # rather than risk having concurrent catch ups going on.
- if self.state == ConnectionStates.CLOSED:
- return
-
- if updates:
- await self.handler.on_rdata(
- cmd.stream_name,
- current_token,
- [stream.parse_row(update[1]) for update in updates],
- )
+ async def handle_command(self, cmd: Command):
+ """Handle a command we have received over the replication stream.
- # We've now caught up to position sent to us, notify handler.
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ Awaits `self.on_<COMMAND>`, then `self.handler.on_<COMMAND>` (each must
+ return an awaitable); logs a warning if neither is defined.
- self.streams_connecting.discard(cmd.stream_name)
- if not self.streams_connecting:
- self.handler.finished_connecting()
+ Args:
+ cmd: received command
+ """
+ handled = False
- # Check if the connection was closed underneath us, if so we bail
- # rather than risk having concurrent catch ups going on.
- if self.state == ConnectionStates.CLOSED:
- return
+ # First call any handler defined on this protocol instance; these are for
+ # TCP-specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
- # Handle any RDATA that came in while we were catching up.
- rows = self.pending_batches.pop(cmd.stream_name, [])
- if rows:
- await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
+ # Then call out to the handler.
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
- async def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.handler.on_remote_server_up(cmd.data)
+ async def on_SERVER(self, cmd):
+ if cmd.data != self.server_name:
+ logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
+ self.send_error("Wrong remote")
def replicate(self):
"""Send the subscription request to the server
@@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge(
for k, count in iteritems(p.outbound_commands_counter)
},
)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
- "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
|