diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..825900f64c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+ # The transport is going to be an ITCPTransport, but that doesn't have the
+ # (un)registerProducer methods, those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections."""
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection"""
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
|