| diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..ee909f3fc5 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
     UserIpCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
     AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
 
 # the type of the entries in _command_queues_by_stream
 _StreamCommandQueue = Deque[
-    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
 ]
 
 
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
 
         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[AbstractConnection]
+        self._connections = []  # type: List[IReplicationConnection]
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
 
         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]
 
         LaterGauge(
             "synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
             self._server_notices_sender = hs.get_server_notices_sender()
 
     def _add_command_to_stream_queue(
-        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
 
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
     async def _process_command(
         self,
         cmd: Union[PositionCommand, RdataCommand],
-        conn: AbstractConnection,
+        conn: IReplicationConnection,
         stream_name: str,
     ) -> None:
         if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
-    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
-    def send_positions_to_connection(self, conn: AbstractConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection):
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
             )
 
     def on_USER_SYNC(
-        self, conn: AbstractConnection, cmd: UserSyncCommand
+        self, conn: IReplicationConnection, cmd: UserSyncCommand
     ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
             return None
 
     def on_CLEAR_USER_SYNC(
-        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
     ) -> Optional[Awaitable[None]]:
         if self._is_master:
             return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
         else:
             return None
 
-    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+    def on_FEDERATION_ACK(
+        self, conn: IReplicationConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
     def on_USER_IP(
-        self, conn: AbstractConnection, cmd: UserIpCommand
+        self, conn: IReplicationConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
-        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
     ) -> None:
         """Process an RDATA command
 
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
-        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
     ) -> None:
         """Process a POSITION command
 
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+    def on_REMOTE_SERVER_UP(
+        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)
 
-    def new_connection(self, connection: AbstractConnection):
+    def new_connection(self, connection: IReplicationConnection):
         """Called when we have a new connection."""
         self._connections.append(connection)
 
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )
 
-    def lost_connection(self, connection: AbstractConnection):
+    def lost_connection(self, connection: IReplicationConnection):
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
         return bool(self._connections)
 
     def send_command(
-        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
     ):
         """Send a command to all connected connections.
 |