diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..d185cc0c8f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,11 +48,12 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -62,29 +63,33 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
+ RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
connection_close_counter = Counter(
- "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -119,10 +124,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
"""
- delimiter = b'\n'
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ delimiter = b"\n"
+
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -141,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -183,10 +192,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if now - self.last_sent_command >= PING_TIME:
self.send_command(PingCommand(now))
- if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+ if (
+ self.received_ping
+ and now - self.last_received_command > PING_TIMEOUT_MS
+ ):
logger.info(
"[%s] Connection hasn't received command in %r ms. Closing.",
- self.id(), now - self.last_received_command
+ self.id(),
+ now - self.last_received_command,
)
self.send_error("ping timeout")
@@ -208,7 +221,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1)
+ self.inbound_commands_counter[cmd_name] + 1
+ )
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -224,27 +238,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
- "replication-" + cmd.get_logcontext_id(),
- self.handle_command,
- cmd,
+ "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
- def handle_command(self, cmd):
+ async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>
+ By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
- cmd (synapse.replication.tcp.commands.Command): received command
-
- Returns:
- Deferred
+ cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
- return handler(cmd)
+ await handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -274,8 +283,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return
self.outbound_commands_counter[cmd.NAME] = (
- self.outbound_commands_counter[cmd.NAME] + 1)
- string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+ self.outbound_commands_counter[cmd.NAME] + 1
+ )
+ string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -283,10 +293,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if len(encoded_string) > self.MAX_LENGTH:
raise Exception(
- "Failed to send command %s as too long (%d > %d)" % (
- cmd.NAME,
- len(encoded_string), self.MAX_LENGTH,
- )
+ "Failed to send command %s as too long (%d > %d)"
+ % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
)
self.sendLine(encoded_string)
@@ -315,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- def on_PING(self, line):
+ async def on_PING(self, line):
self.received_ping = True
- def on_ERROR(self, cmd):
+ async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -379,7 +387,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
addr = str(self.transport.getPeer())
return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
- self.name, self.conn_id, addr,
+ self.name,
+ self.conn_id,
+ addr,
)
def id(self):
@@ -402,68 +412,69 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
- def on_NAME(self, cmd):
+ async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- def on_USER_SYNC(self, cmd):
- return self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+ async def on_USER_SYNC(self, cmd):
+ await self.streamer.on_user_sync(
+ self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- def on_REPLICATE(self, cmd):
+ async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
- run_in_background(
- self.subscribe_to_stream,
- stream, token,
- )
+ run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
- return make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
- return self.subscribe_to_stream(stream_name, token)
+ await self.subscribe_to_stream(stream_name, token)
- def on_FEDERATION_ACK(self, cmd):
- return self.streamer.federation_ack(cmd.token)
+ async def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
- def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(
- cmd.app_id, cmd.push_key, cmd.user_id,
- )
+ async def on_REMOVE_PUSHER(self, cmd):
+ await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
- def on_INVALIDATE_CACHE(self, cmd):
- return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ async def on_INVALIDATE_CACHE(self, cmd):
+ await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
- def on_USER_IP(self, cmd):
- return self.streamer.on_user_ip(
- cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        # Forward the command payload (cmd.data) to the streamer. The value
+        # returned by the streamer is not awaited here.
+        self.streamer.on_remote_server_up(cmd.data)
+
+ async def on_USER_IP(self, cmd):
+ await self.streamer.on_user_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
cmd.last_seen,
)
- @defer.inlineCallbacks
- def subscribe_to_stream(self, stream_name, token):
+ async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
@@ -475,8 +486,8 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
- updates, current_token = yield self.streamer.get_stream_updates(
- stream_name, token,
+ updates, current_token = await self.streamer.get_stream_updates(
+ stream_name, token
)
# Send all the missing updates
@@ -548,16 +559,90 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
self.send_command(SyncCommand(data))
+    def send_remote_server_up(self, server: str):
+        """Build a RemoteServerUpCommand for `server` and pass it to send_command."""
+        self.send_command(RemoteServerUpCommand(server))
+
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+    """
+    The interface for the handler that should be passed to
+    ClientReplicationStreamProtocol.
+
+    Every method is abstract. The methods declared with `async def` are
+    coroutines: calling one only creates the coroutine, and its body runs
+    only once the caller awaits the result.
+    """
+
+    @abc.abstractmethod
+    async def on_rdata(self, stream_name, token, rows):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    async def on_position(self, stream_name, token):
+        """Called when we get new position data.
+
+        Args:
+            stream_name: name of the stream the position refers to
+            token: the token carried by the position update
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_sync(self, data):
+        """Called when we get a new SYNC command.
+
+        Args:
+            data: the data carried by the SYNC command
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    async def on_remote_server_up(self, server: str):
+        """Called when we get a new REMOTE_SERVER_UP command.
+
+        This is declared as a coroutine, so the caller has to await the
+        result for the body to run.
+
+        Args:
+            server: the data carried by the REMOTE_SERVER_UP command
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_streams_to_replicate(self):
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def update_connection(self, connection):
+        """Called when a connection has been established (or lost with None).
+
+        Args:
+            connection: the established connection, or None when the
+                connection has been lost
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def finished_connecting(self):
+        """Called when we have successfully subscribed and caught up to all
+        streams we're interested in.
+        """
+        raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
@@ -567,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -595,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
- def on_SERVER(self, cmd):
+ async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- def on_RDATA(self, cmd):
+ async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -608,8 +693,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
- "[%s] Failed to parse RDATA: %r %r",
- self.id(), stream_name, cmd.row
+ "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
@@ -621,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- return self.handler.on_rdata(stream_name, cmd.token, rows)
+ await self.handler.on_rdata(stream_name, cmd.token, rows)
- def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- return self.handler.on_position(cmd.stream_name, cmd.token)
+ await self.handler.on_position(cmd.stream_name, cmd.token)
- def on_SYNC(self, cmd):
- return self.handler.on_sync(cmd.data)
+ async def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
+
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        """Pass a REMOTE_SERVER_UP command on to the handler.
+
+        AbstractReplicationClientHandler declares on_remote_server_up as a
+        coroutine, so its result has to be awaited or the handler body never
+        runs. A handler that implements it as a plain function returns a
+        non-awaitable value, which is left alone.
+
+        Args:
+            cmd: the received command; cmd.data is handed to the handler.
+        """
+        result = self.handler.on_remote_server_up(cmd.data)
+        # Coroutines and Deferreds both expose __await__; a plain return
+        # value (e.g. None from a synchronous handler) does not.
+        if hasattr(result, "__await__"):
+            await result
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -643,7 +730,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
- self.id(), stream_name, token
+ self.id(),
+ stream_name,
+ token,
)
self.streams_connecting.add(stream_name)
@@ -661,9 +750,7 @@ pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name"],
- lambda: {
- (p.name,): len(p.pending_commands) for p in connected_connections
- },
+ lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
)
@@ -678,9 +765,7 @@ transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name"],
- lambda: {
- (p.name,): transport_buffer_size(p) for p in connected_connections
- },
+ lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
)
@@ -694,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
@@ -726,7 +811,7 @@ tcp_inbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -737,7 +822,7 @@ tcp_outbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
|