summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py121
1 files changed, 67 insertions, 54 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index afaf002fe6..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,6 +53,7 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
 
 from six import iteritems, iterkeys
 
@@ -65,24 +66,26 @@ from twisted.python.failure import Failure
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
     COMMAND_MAP,
     VALID_CLIENT_COMMANDS,
     VALID_SERVER_COMMANDS,
+    Command,
     ErrorCommand,
     NameCommand,
     PingCommand,
     PositionCommand,
     RdataCommand,
+    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
     SyncCommand,
     UserSyncCommand,
 )
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
 
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
@@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     delimiter = b"\n"
 
-    VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
-    VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
+    # Valid commands we expect to receive
+    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
+
+    # Valid commands we can send
+    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
 
     max_line_buffer = 10000
 
@@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.conn_id = random_string(5)  # To dedupe in case of name clashes.
 
         # List of pending commands to send once we've established the connection
-        self.pending_commands = []
+        self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
-        self.inbound_commands_counter = defaultdict(int)
-        self.outbound_commands_counter = defaultdict(int)
+        self.inbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
+        self.outbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -235,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
-    def handle_command(self, cmd):
+    async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>
+        By default delegates to on_<COMMAND>, which should return an awaitable.
 
         Args:
-            cmd (synapse.replication.tcp.commands.Command): received command
-
-        Returns:
-            Deferred
+            cmd: received command
         """
         handler = getattr(self, "on_%s" % (cmd.NAME,))
-        return handler(cmd)
+        await handler(cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -320,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    async def on_PING(self, line):
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    async def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -409,30 +412,30 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.streamer = streamer
 
         # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()
+        self.replication_streams = set()  # type: Set[str]
 
         # The streams the client is currently subscribing to.
-        self.connecting_streams = set()
+        self.connecting_streams = set()  # type:  Set[str]
 
         # Map from stream name to list of updates to send once we've finished
         # subscribing the client to the stream.
-        self.pending_rdata = {}
+        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
         self.streamer.new_connection(self)
 
-    def on_NAME(self, cmd):
+    async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    def on_USER_SYNC(self, cmd):
-        return self.streamer.on_user_sync(
+    async def on_USER_SYNC(self, cmd):
+        await self.streamer.on_user_sync(
             self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
-    def on_REPLICATE(self, cmd):
+    async def on_REPLICATE(self, cmd):
         stream_name = cmd.stream_name
         token = cmd.token
 
@@ -443,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
-            return make_deferred_yieldable(
+            await make_deferred_yieldable(
                 defer.gatherResults(deferreds, consumeErrors=True)
             )
         else:
-            return self.subscribe_to_stream(stream_name, token)
+            await self.subscribe_to_stream(stream_name, token)
 
-    def on_FEDERATION_ACK(self, cmd):
-        return self.streamer.federation_ack(cmd.token)
+    async def on_FEDERATION_ACK(self, cmd):
+        self.streamer.federation_ack(cmd.token)
 
-    def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+    async def on_REMOVE_PUSHER(self, cmd):
+        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
-    def on_INVALIDATE_CACHE(self, cmd):
-        return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+    async def on_INVALIDATE_CACHE(self, cmd):
+        self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
-    def on_USER_IP(self, cmd):
-        return self.streamer.on_user_ip(
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.streamer.on_remote_server_up(cmd.data)
+
+    async def on_USER_IP(self, cmd):
+        self.streamer.on_user_ip(
             cmd.user_id,
             cmd.access_token,
             cmd.ip,
@@ -468,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    @defer.inlineCallbacks
-    def subscribe_to_stream(self, stream_name, token):
+    async def subscribe_to_stream(self, stream_name, token):
         """Subscribe the remote to a stream.
 
         This invloves checking if they've missed anything and sending those
@@ -481,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         try:
             # Get missing updates
-            updates, current_token = yield self.streamer.get_stream_updates(
+            updates, current_token = await self.streamer.get_stream_updates(
                 stream_name, token
             )
 
@@ -554,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
         self.streamer.lost_connection(self)
@@ -566,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractmethod
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
@@ -574,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
             token (int): stream token for this batch of rows
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
-
-        Returns:
-            Deferred|None
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def on_position(self, stream_name, token):
+    async def on_position(self, stream_name, token):
         """Called when we get new position data."""
         raise NotImplementedError()
 
@@ -591,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    async def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_streams_to_replicate(self):
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -642,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()
+        self.streams_connecting = set()  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}
+        self.pending_batches = {}  # type: Dict[str, Any]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
@@ -670,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-    def on_SERVER(self, cmd):
+    async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def on_RDATA(self, cmd):
+    async def on_RDATA(self, cmd):
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
@@ -695,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             # Check if this is the last of a batch of updates
             rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
-            return self.handler.on_rdata(stream_name, cmd.token, rows)
+            await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    def on_POSITION(self, cmd):
+    async def on_POSITION(self, cmd):
         # When we get a `POSITION` command it means we've finished getting
         # missing updates for the given stream, and are now up to date.
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        return self.handler.on_position(cmd.stream_name, cmd.token)
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
+    async def on_SYNC(self, cmd):
+        self.handler.on_sync(cmd.data)
 
-    def on_SYNC(self, cmd):
-        return self.handler.on_sync(cmd.data)
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.handler.on_remote_server_up(cmd.data)
 
     def replicate(self, stream_name, token):
         """Send the subscription request to the server
@@ -766,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
         return size
     return 0