summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py165
1 files changed, 51 insertions, 114 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f2a37f568e..9aabb9c586 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
+import abc
 import fcntl
 import logging
 import struct
@@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import (
     ErrorCommand,
     NameCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
-    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
 )
 from synapse.types import Collection
 from synapse.util import Clock
@@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     are only sent by the server.
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
-    command.
+    command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     max_line_buffer = 10000
 
-    def __init__(self, clock):
+    def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
         self.clock = clock
+        self.command_handler = handler
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
@@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
 
+        self.command_handler.new_connection(self)
+
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
@@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        First calls `self.on_<COMMAND>` if it exists, then calls
+        `self.command_handler.on_<COMMAND>` if it exists. This allows for
+        protocol level handling of commands (e.g. PINGs), before delegating to
+        the handler.
 
         Args:
             cmd: received command
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        await handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        self.command_handler.lost_connection(self)
+
         if self.transport:
             self.transport.unregisterProducer()
 
@@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
 
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(
+        self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+    ):
+        super().__init__(clock, handler)
 
         self.server_name = server_name
-        self.streamer = streamer
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
+        super().connectionMade()
 
     async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    async def on_USER_SYNC(self, cmd):
-        await self.streamer.on_user_sync(
-            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    async def on_CLEAR_USER_SYNC(self, cmd):
-        await self.streamer.on_clear_user_syncs(cmd.instance_id)
-
-    async def on_REPLICATE(self, cmd):
-        # Subscribe to all streams we're publishing to.
-        for stream_name in self.streamer.streams_by_name:
-            current_token = self.streamer.get_stream_token(stream_name)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-    async def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
-
-    async def on_REMOVE_PUSHER(self, cmd):
-        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    async def on_INVALIDATE_CACHE(self, cmd):
-        await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.streamer.on_remote_server_up(cmd.data)
-
-    async def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        self.send_command(RdataCommand(stream_name, token, data))
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         clock: Clock,
         command_handler: "ReplicationCommandHandler",
     ):
-        BaseReplicationStreamProtocol.__init__(self, clock)
-
-        self.instance_id = hs.get_instance_id()
+        super().__init__(clock, command_handler)
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = command_handler
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
+        super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
-        self.handler.finished_connecting()
-
-    async def handle_command(self, cmd: Command):
-        """Handle a command we have received over the replication stream.
-
-        Delegates to `command_handler.on_<COMMAND>`, which must return an
-        awaitable.
-
-        Args:
-            cmd: received command
-        """
-        handled = False
-
-        # First call any command handlers on this instance. These are for TCP
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
-
-        # Then call out to the handler.
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
-
-        if not handled:
-            logger.warning("Unhandled command: %r", cmd)
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         self.send_command(ReplicateCommand())
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
+
+class AbstractConnection(abc.ABC):
+    """An interface for replication connections.
+    """
+
+    @abc.abstractmethod
+    def send_command(self, cmd: Command):
+        """Send the command down the connection
+        """
+        pass
+
+
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
 
 
 # The following simply registers metrics for the replication connections