summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py68
1 files changed, 35 insertions, 33 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 7bae36db16..7763ffb2d0 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,7 +49,7 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Collection, List, Optional
+from typing import TYPE_CHECKING, Any, Collection, List, Optional
 
 from prometheus_client import Counter
 from zope.interface import Interface, implementer
@@ -123,7 +123,7 @@ class ConnectionStates:
 class IReplicationConnection(Interface):
     """An interface for replication connections."""
 
-    def send_command(cmd: Command):
+    def send_command(cmd: Command) -> None:
         """Send the command down the connection"""
 
 
@@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 "replication-conn", self.conn_id
             )
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("[%s] Connection established", self.id())
 
         self.state = ConnectionStates.ESTABLISHED
@@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Always send the initial PING so that the other side knows that they
         # can time us out.
-        self.send_command(PingCommand(self.clock.time_msec()))
+        self.send_command(PingCommand(str(self.clock.time_msec())))
 
         self.command_handler.new_connection(self)
 
-    def send_ping(self):
+    def send_ping(self) -> None:
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
         """
@@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 self.transport.abortConnection()
         else:
             if now - self.last_sent_command >= PING_TIME:
-                self.send_command(PingCommand(now))
+                self.send_command(PingCommand(str(now)))
 
             if (
                 self.received_ping
@@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line: bytes):
+    def lineReceived(self, line: bytes) -> None:
         """Called when we've received a line"""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_line(line)
 
-    def _parse_and_dispatch_line(self, line: bytes):
+    def _parse_and_dispatch_line(self, line: bytes) -> None:
         if line.strip() == "":
             # Ignore blank lines
             return
@@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if not handled:
             logger.warning("Unhandled command: %r", cmd)
 
-    def close(self):
+    def close(self) -> None:
         logger.warning("[%s] Closing connection", self.id())
         self.time_we_closed = self.clock.time_msec()
         assert self.transport is not None
         self.transport.loseConnection()
         self.on_connection_closed()
 
-    def send_error(self, error_string, *args):
+    def send_error(self, error_string: str, *args: Any) -> None:
         """Send an error to remote and close the connection."""
         self.send_command(ErrorCommand(error_string % args))
         self.close()
 
-    def send_command(self, cmd, do_buffer=True):
+    def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
         """Send a command if connection has been established.
 
         Args:
-            cmd (Command)
-            do_buffer (bool): Whether to buffer the message or always attempt
+            cmd
+            do_buffer: Whether to buffer the message or always attempt
                 to send the command. This is mostly used to send an error
                 message if we're about to close the connection due our buffers
                 becoming full.
@@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_sent_command = self.clock.time_msec()
 
-    def _queue_command(self, cmd):
+    def _queue_command(self, cmd: Command) -> None:
         """Queue the command until the connection is ready to write to again."""
         logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
@@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
             self.close()
 
-    def _send_pending_commands(self):
+    def _send_pending_commands(self) -> None:
         """Send any queued commandes"""
         pending = self.pending_commands
         self.pending_commands = []
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    def on_PING(self, cmd: PingCommand) -> None:
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd: ErrorCommand) -> None:
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         """This is called when both the kernel send buffer and the twisted
         tcp connection send buffers have become full.
 
@@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         logger.info("[%s] Pause producing", self.id())
         self.state = ConnectionStates.PAUSED
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         """The remote has caught up after we started buffering!"""
         logger.info("[%s] Resume producing", self.id())
         self.state = ConnectionStates.ESTABLISHED
         self._send_pending_commands()
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         """We're never going to send any more data (normally because either
         we or the remote has closed the connection)
         """
         logger.info("[%s] Stop producing", self.id())
         self.on_connection_closed()
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
         if isinstance(reason, Failure):
             assert reason.type is not None
             connection_close_counter.labels(reason.type.__name__).inc()
         else:
-            connection_close_counter.labels(reason.__class__.__name__).inc()
+            connection_close_counter.labels(reason.__class__.__name__).inc()  # type: ignore[unreachable]
 
         try:
             # Remove us from list of connections to be monitored
@@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.on_connection_closed()
 
-    def on_connection_closed(self):
+    def on_connection_closed(self) -> None:
         logger.info("[%s] Connection was closed", self.id())
 
         self.state = ConnectionStates.CLOSED
@@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             # the sentinel context is now active, which may not be correct.
             # PreserveLoggingContext() will restore the correct logging context.
 
-    def __str__(self):
+    def __str__(self) -> str:
         addr = None
         if self.transport:
             addr = str(self.transport.getPeer())
@@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             addr,
         )
 
-    def id(self):
+    def id(self) -> str:
         return "%s-%s" % (self.name, self.conn_id)
 
-    def lineLengthExceeded(self, line):
+    def lineLengthExceeded(self, line: str) -> None:
         """Called when we receive a line that is above the maximum line length"""
         self.send_error("Line length exceeded")
 
@@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         self.server_name = server_name
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    def on_NAME(self, cmd):
+    def on_NAME(self, cmd: NameCommand) -> None:
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
@@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.client_name = client_name
         self.server_name = server_name
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(NameCommand(self.client_name))
         super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd: ServerCommand) -> None:
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Send the subscription request to the server"""
         logger.info("[%s] Subscribing to replication streams", self.id())
 
@@ -529,7 +529,7 @@ pending_commands = LaterGauge(
 )
 
 
-def transport_buffer_size(protocol):
+def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
     if protocol.transport:
         size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
         return size
@@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
 )
 
 
-def transport_kernel_read_buffer_size(protocol, read=True):
+def transport_kernel_read_buffer_size(
+    protocol: BaseReplicationStreamProtocol, read: bool = True
+) -> int:
     SIOCINQ = 0x541B
     SIOCOUTQ = 0x5411