diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f2a37f568e..9aabb9c586 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
+import abc
import fcntl
import logging
import struct
@@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import (
ErrorCommand,
NameCommand,
PingCommand,
- PositionCommand,
- RdataCommand,
- RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
- SyncCommand,
- UserSyncCommand,
)
from synapse.types import Collection
from synapse.util import Clock
@@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
are only sent by the server.
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
- command.
+ command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000
- def __init__(self, clock):
+ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
self.clock = clock
+ self.command_handler = handler
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
@@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
+ self.command_handler.new_connection(self)
+
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ First calls `self.on_<COMMAND>` if it exists, then calls
+ `self.command_handler.on_<COMMAND>` if it exists. This allows for
+ protocol level handling of commands (e.g. PINGs), before delegating to
+ the handler.
Args:
cmd: received command
"""
- handler = getattr(self, "on_%s" % (cmd.NAME,))
- await handler(cmd)
+ handled = False
+
+ # First call any command handlers on this instance. These are for TCP
+ # specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ # Then call out to the handler.
+ cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
self.pending_commands = []
+ self.command_handler.lost_connection(self)
+
if self.transport:
self.transport.unregisterProducer()
@@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
- def __init__(self, server_name, clock, streamer):
- BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
+ def __init__(
+ self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+ ):
+ super().__init__(clock, handler)
self.server_name = server_name
- self.streamer = streamer
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
- BaseReplicationStreamProtocol.connectionMade(self)
- self.streamer.new_connection(self)
+ super().connectionMade()
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- async def on_USER_SYNC(self, cmd):
- await self.streamer.on_user_sync(
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
- )
-
- async def on_CLEAR_USER_SYNC(self, cmd):
- await self.streamer.on_clear_user_syncs(cmd.instance_id)
-
- async def on_REPLICATE(self, cmd):
- # Subscribe to all streams we're publishing to.
- for stream_name in self.streamer.streams_by_name:
- current_token = self.streamer.get_stream_token(stream_name)
- self.send_command(PositionCommand(stream_name, current_token))
-
- async def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
-
- async def on_REMOVE_PUSHER(self, cmd):
- await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
- async def on_INVALIDATE_CACHE(self, cmd):
- await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.streamer.on_remote_server_up(cmd.data)
-
- async def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
-
- def stream_update(self, stream_name, token, data):
- """Called when a new update is available to stream to clients.
-
- We need to check if the client is interested in the stream or not
- """
- self.send_command(RdataCommand(stream_name, token, data))
-
- def send_sync(self, data):
- self.send_command(SyncCommand(data))
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def on_connection_closed(self):
- BaseReplicationStreamProtocol.on_connection_closed(self)
- self.streamer.lost_connection(self)
-
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
clock: Clock,
command_handler: "ReplicationCommandHandler",
):
- BaseReplicationStreamProtocol.__init__(self, clock)
-
- self.instance_id = hs.get_instance_id()
+ super().__init__(clock, command_handler)
self.client_name = client_name
self.server_name = server_name
- self.handler = command_handler
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
- BaseReplicationStreamProtocol.connectionMade(self)
+ super().connectionMade()
# Once we've connected subscribe to the necessary streams
self.replicate()
- # Tell the server if we have any users currently syncing (should only
- # happen on synchrotrons)
- currently_syncing = self.handler.get_currently_syncing_users()
- now = self.clock.time_msec()
- for user_id in currently_syncing:
- self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
-
- # We've now finished connecting to so inform the client handler
- self.handler.update_connection(self)
- self.handler.finished_connecting()
-
- async def handle_command(self, cmd: Command):
- """Handle a command we have received over the replication stream.
-
- Delegates to `command_handler.on_<COMMAND>`, which must return an
- awaitable.
-
- Args:
- cmd: received command
- """
- handled = False
-
- # First call any command handlers on this instance. These are for TCP
- # specific handling.
- cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
-
- # Then call out to the handler.
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
-
- if not handled:
- logger.warning("Unhandled command: %r", cmd)
-
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
- def on_connection_closed(self):
- BaseReplicationStreamProtocol.on_connection_closed(self)
- self.handler.update_connection(None)
+
+class AbstractConnection(abc.ABC):
+ """An interface for replication connections.
+ """
+
+ @abc.abstractmethod
+ def send_command(self, cmd: Command):
+ """Send the command down the connection
+ """
+ pass
+
+
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections
|