summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/handler.py44
-rw-r--r--synapse/replication/tcp/protocol.py24
-rw-r--r--synapse/replication/tcp/redis.py8
3 files changed, 38 insertions, 38 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..ee909f3fc5 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
     UserIpCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
     AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
 
 # the type of the entries in _command_queues_by_stream
 _StreamCommandQueue = Deque[
-    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
 ]
 
 
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
 
         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[AbstractConnection]
+        self._connections = []  # type: List[IReplicationConnection]
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
 
         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]
 
         LaterGauge(
             "synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
             self._server_notices_sender = hs.get_server_notices_sender()
 
     def _add_command_to_stream_queue(
-        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
 
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
     async def _process_command(
         self,
         cmd: Union[PositionCommand, RdataCommand],
-        conn: AbstractConnection,
+        conn: IReplicationConnection,
         stream_name: str,
     ) -> None:
         if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
-    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
-    def send_positions_to_connection(self, conn: AbstractConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection):
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
             )
 
     def on_USER_SYNC(
-        self, conn: AbstractConnection, cmd: UserSyncCommand
+        self, conn: IReplicationConnection, cmd: UserSyncCommand
     ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
             return None
 
     def on_CLEAR_USER_SYNC(
-        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
     ) -> Optional[Awaitable[None]]:
         if self._is_master:
             return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
         else:
             return None
 
-    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+    def on_FEDERATION_ACK(
+        self, conn: IReplicationConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
     def on_USER_IP(
-        self, conn: AbstractConnection, cmd: UserIpCommand
+        self, conn: IReplicationConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
-        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
     ) -> None:
         """Process an RDATA command
 
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
-        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
     ) -> None:
         """Process a POSITION command
 
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+    def on_REMOTE_SERVER_UP(
+        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)
 
-    def new_connection(self, connection: AbstractConnection):
+    def new_connection(self, connection: IReplicationConnection):
         """Called when we have a new connection."""
         self._connections.append(connection)
 
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )
 
-    def lost_connection(self, connection: AbstractConnection):
+    def lost_connection(self, connection: IReplicationConnection):
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
         return bool(self._connections)
 
     def send_command(
-        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
     ):
         """Send a command to all connected connections.
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..8e4734b59c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-import abc
 import fcntl
 import logging
 import struct
@@ -54,6 +53,7 @@ from inspect import isawaitable
 from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
+from zope.interface import Interface, implementer
 
 from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
@@ -121,6 +121,14 @@ class ConnectionStates:
     CLOSED = "closed"
 
 
+class IReplicationConnection(Interface):
+    """An interface for replication connections."""
+
+    def send_command(cmd: Command):
+        """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
 class BaseReplicationStreamProtocol(LineOnlyReceiver):
     """Base replication protocol shared between client and server.
 
@@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.send_command(ReplicateCommand())
 
 
-class AbstractConnection(abc.ABC):
-    """An interface for replication connections."""
-
-    @abc.abstractmethod
-    def send_command(self, cmd: Command):
-        """Send the command down the connection"""
-        pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
 # The following simply registers metrics for the replication connections
 
 pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 574eaea1eb..7cccde097d 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
 
 import attr
 import txredisapi
+from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import IAddress, IConnector
@@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import (
     parse_command_from_line,
 )
 from synapse.replication.tcp.protocol import (
-    AbstractConnection,
+    IReplicationConnection,
     tcp_inbound_commands_counter,
     tcp_outbound_commands_counter,
 )
@@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
         pass
 
 
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
     """Connection to redis subscribed to replication stream.
 
     This class fulfils two functions:
@@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     connection, parsing *incoming* messages into replication commands, and passing them
     to `ReplicationCommandHandler`
 
-    (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+    (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
     onto outbound_redis_connection.
 
     Due to the vagaries of `txredisapi` we don't want to have a custom