diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..7240acb0a2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
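The handshake change above is the heart of this diff: rather than one `REPLICATE <stream> <token>` per stream, the client now sends a single argument-less `REPLICATE`, and the server answers with a `POSITION` for every stream. A consequence is that a command's payload can now be empty on the wire. The `parse_command_from_line` helper imported below lives in `synapse/replication/tcp/commands.py` and is not shown in this diff; here is a minimal sketch of the shape it plausibly takes, inferred from how `lineReceived` uses it (the `Command` and `COMMAND_MAP` stubs are illustrative stand-ins, not the real definitions):

```python
from typing import Dict, Type


class Command:
    """Illustrative stand-in for synapse.replication.tcp.commands.Command."""

    NAME = ""

    @classmethod
    def from_line(cls, line: str) -> "Command":
        raise NotImplementedError()


# Maps each command name to the Command subclass that parses its payload.
COMMAND_MAP = {}  # type: Dict[str, Type[Command]]


def parse_command_from_line(line: str) -> Command:
    """Parse `NAME[ <payload>]` into a Command instance.

    The payload may be absent entirely (e.g. a bare `REPLICATE`), so we
    cannot unconditionally split on the first space.
    """
    idx = line.find(" ")
    if idx >= 0:
        cmd_name, rest_of_line = line[:idx], line[idx + 1 :]
    else:
        cmd_name, rest_of_line = line, ""
    # An unknown name raises KeyError, which lineReceived reports via its
    # catch-all except clause.
    cmd_cls = COMMAND_MAP[cmd_name]
    return cmd_cls.from_line(rest_of_line)
```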
@@ -53,40 +51,37 @@ import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import TYPE_CHECKING, DefaultDict, List
 
-from six import iteritems, iterkeys
+from six import iteritems
 
from prometheus_client import Counter
 
-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
- COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
- PositionCommand,
- RdataCommand,
- RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
- SyncCommand,
- UserSyncCommand,
+ parse_command_from_line,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
+
+
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -119,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
are only sent by the server.
 
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
- command.
+ command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
 
It also sends `PING` periodically, and correctly times out remote connections
     (once they have sent a `PING` command and then gone quiet)
@@ -135,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000
 
- def __init__(self, clock):
+ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
self.clock = clock
+ self.command_handler = handler
 
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
 
@@ -176,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
 
+ self.command_handler.new_connection(self)
+
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -203,39 +201,33 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
)
self.send_error("ping timeout")
- def lineReceived(self, line):
+ def lineReceived(self, line: bytes):
"""Called when we've received a line
"""
if line.strip() == "":
# Ignore blank lines
return
 
- line = line.decode("utf-8")
- cmd_name, rest_of_line = line.split(" ", 1)
+ linestr = line.decode("utf-8")
+
+ try:
+ cmd = parse_command_from_line(linestr)
+ except Exception as e:
+ logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
+            self.send_error("failed to parse line %r: %r" % (linestr, e))
+ return
 
- if cmd_name not in self.VALID_INBOUND_COMMANDS:
- logger.error("[%s] invalid command %s", self.id(), cmd_name)
- self.send_error("invalid command: %s", cmd_name)
+ if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
+ logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
+ self.send_error("invalid command: %s", cmd.NAME)
return
 
self.last_received_command = self.clock.time_msec()
 
- self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1
+ self.inbound_commands_counter[cmd.NAME] = (
+ self.inbound_commands_counter[cmd.NAME] + 1
)
 
- cmd_cls = COMMAND_MAP[cmd_name]
- try:
- cmd = cmd_cls.from_line(rest_of_line)
- except Exception as e:
- logger.exception(
- "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
- )
- self.send_error(
- "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
- )
- return
-
         # Now let's try and call the on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
@@ -244,13 +236,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ First calls `self.on_<COMMAND>` if it exists, then calls
+ `self.command_handler.on_<COMMAND>` if it exists. This allows for
+        protocol-level handling of commands (e.g. PINGs) before delegating to
+ the handler.
 
Args:
cmd: received command
"""
- handler = getattr(self, "on_%s" % (cmd.NAME,))
- await handler(cmd)
+ handled = False
+
+        # First call any command handlers on this instance. These are for
+        # TCP-specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ # Then call out to the handler.
+ cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
 
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -379,6 +389,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
 
self.pending_commands = []
 
+ self.command_handler.lost_connection(self)
+
if self.transport:
self.transport.unregisterProducer()
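Paired with the `new_connection` call added to `connectionMade` above, this gives `ReplicationCommandHandler` sight of every connection's full lifecycle. A hypothetical sketch of the minimal bookkeeping these two calls imply; the real class lives in `synapse/replication/tcp/handler.py` and does considerably more:

```python
from typing import List


class ReplicationCommandHandler:
    """Hypothetical sketch: tracks live replication connections."""

    def __init__(self):
        self._connections = []  # type: List[object]

    def new_connection(self, connection):
        # Called from connectionMade once the connection is usable.
        self._connections.append(connection)

    def lost_connection(self, connection):
        # Called from on_connection_closed; drop our reference.
        try:
            self._connections.remove(connection)
        except ValueError:
            pass
```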
@@ -405,232 +417,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
 
- def __init__(self, server_name, clock, streamer):
- BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
+ def __init__(
+ self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+ ):
+ super().__init__(clock, handler)
 
self.server_name = server_name
- self.streamer = streamer
-
- # The streams the client has subscribed to and is up to date with
- self.replication_streams = set() # type: Set[str]
-
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
- # Map from stream name to list of updates to send once we've finished
- # subscribing the client to the stream.
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
 
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
- BaseReplicationStreamProtocol.connectionMade(self)
- self.streamer.new_connection(self)
+ super().connectionMade()
 
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
 
- async def on_USER_SYNC(self, cmd):
- await self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
- )
-
- async def on_REPLICATE(self, cmd):
- stream_name = cmd.stream_name
- token = cmd.token
-
- if stream_name == "ALL":
- # Subscribe to all streams we're publishing to.
- deferreds = [
- run_in_background(self.subscribe_to_stream, stream, token)
- for stream in iterkeys(self.streamer.streams_by_name)
- ]
-
- await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- else:
- await self.subscribe_to_stream(stream_name, token)
-
- async def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
-
- async def on_REMOVE_PUSHER(self, cmd):
- await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
- async def on_INVALIDATE_CACHE(self, cmd):
- await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.streamer.on_remote_server_up(cmd.data)
-
- async def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
-
- async def subscribe_to_stream(self, stream_name, token):
- """Subscribe the remote to a stream.
-
- This invloves checking if they've missed anything and sending those
- updates down if they have. During that time new updates for the stream
- are queued and sent once we've sent down any missed updates.
- """
- self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
-
- try:
- # Get missing updates
- updates, current_token = await self.streamer.get_stream_updates(
- stream_name, token
- )
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
-
- # We send a POSITION command to ensure that they have an up to
- # date token (especially useful if we didn't send any updates
- # above)
- self.send_command(PositionCommand(stream_name, current_token))
-
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
- # They're now fully subscribed
- self.replication_streams.add(stream_name)
- except Exception as e:
- logger.exception("[%s] Failed to handle REPLICATE command", self.id())
- self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
-
- def stream_update(self, stream_name, token, data):
- """Called when a new update is available to stream to clients.
-
- We need to check if the client is interested in the stream or not
- """
- if stream_name in self.replication_streams:
- # The client is subscribed to the stream
- self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
- else:
- # The client isn't subscribed
- logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
-
- def send_sync(self, data):
- self.send_command(SyncCommand(data))
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def on_connection_closed(self):
- BaseReplicationStreamProtocol.on_connection_closed(self)
- self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
- """
- The interface for the handler that should be passed to
- ClientReplicationStreamProtocol
- """
-
- @abc.abstractmethod
- async def on_rdata(self, stream_name, token, rows):
- """Called to handle a batch of replication data with a given stream token.
-
- Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_position(self, stream_name, token):
- """Called when we get new position data."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def on_sync(self, data):
- """Called when get a new SYNC command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_streams_to_replicate(self):
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- raise NotImplementedError()
-
 
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -638,110 +439,51 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
- handler: AbstractReplicationClientHandler,
+ command_handler: "ReplicationCommandHandler",
):
- BaseReplicationStreamProtocol.__init__(self, clock)
+ super().__init__(clock, command_handler)
 
self.client_name = client_name
self.server_name = server_name
- self.handler = handler
-
- # Set of stream names that have been subscribe to, but haven't yet
- # caught up with. This is used to track when the client has been fully
- # connected to the remote.
- self.streams_connecting = set() # type: Set[str]
-
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
- self.pending_batches = {} # type: Dict[str, Any]
 
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
- BaseReplicationStreamProtocol.connectionMade(self)
+ super().connectionMade()
 
# Once we've connected subscribe to the necessary streams
- for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
- self.replicate(stream_name, token)
-
- # Tell the server if we have any users currently syncing (should only
- # happen on synchrotrons)
- currently_syncing = self.handler.get_currently_syncing_users()
- now = self.clock.time_msec()
- for user_id in currently_syncing:
- self.send_command(UserSyncCommand(user_id, True, now))
-
- # We've now finished connecting to so inform the client handler
- self.handler.update_connection(self)
-
- # This will happen if we don't actually subscribe to any streams
- if not self.streams_connecting:
- self.handler.finished_connecting()
+ self.replicate()
 
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- async def on_RDATA(self, cmd):
- stream_name = cmd.stream_name
- inbound_rdata_count.labels(stream_name).inc()
-
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception(
- "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
- )
- raise
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream. Batch
- # until we get an update for the stream with a non None token
- self.pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(stream_name, [])
- rows.append(row)
- await self.handler.on_rdata(stream_name, cmd.token, rows)
-
- async def on_POSITION(self, cmd):
- # When we get a `POSITION` command it means we've finished getting
- # missing updates for the given stream, and are now up to date.
- self.streams_connecting.discard(cmd.stream_name)
- if not self.streams_connecting:
- self.handler.finished_connecting()
-
-        await self.handler.on_position(cmd.stream_name, cmd.token)
-
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
-
-    def replicate(self, stream_name, token):
-        """Send the subscription request to the server
-        """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
-
-        self.send_command(ReplicateCommand(stream_name, token))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
+    def replicate(self):
+        """Send the subscription request to the server
+        """
+        logger.info("[%s] Subscribing to replication streams", self.id())
+
+        self.send_command(ReplicateCommand())
+
+
+class AbstractConnection(abc.ABC):
+    """An interface for replication connections.
+    """
+
+    @abc.abstractmethod
+    def send_command(self, cmd: Command):
+        """Send the command down the connection
+        """
+        pass
+
+
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
 
 
# The following simply registers metrics for the replication connections
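`AbstractConnection.register` above is the standard-library "virtual subclass" mechanism: after it runs, `isinstance` and `issubclass` checks against `AbstractConnection` succeed for `BaseReplicationStreamProtocol` without the protocol inheriting from the ABC. Note that, unlike real inheritance, `register()` performs no verification that the abstract methods are actually implemented. A self-contained illustration with a stand-in class:

```python
import abc


class AbstractConnection(abc.ABC):
    @abc.abstractmethod
    def send_command(self, cmd):
        """Send the command down the connection"""


class FakeConnection:  # stand-in for BaseReplicationStreamProtocol
    def send_command(self, cmd):
        print("sending %r" % (cmd,))


AbstractConnection.register(FakeConnection)

# Both checks pass even though FakeConnection does not inherit the ABC.
assert issubclass(FakeConnection, AbstractConnection)
assert isinstance(FakeConnection(), AbstractConnection)
```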
@@ -827,8 +569,3 @@ tcp_outbound_commands = LaterGauge(
for k, count in iteritems(p.outbound_commands_counter)
},
)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
- "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)