diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 36071feb36..4ec1bfa6ea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
+ res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..a8894beadf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
- Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+ Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[AbstractConnection]
+ self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+ self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
- self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
- conn: AbstractConnection,
+ conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
- hs.config.redis.redis_host,
+ hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self._factory)
+ hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
- def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
- def send_positions_to_connection(self, conn: AbstractConnection):
+ def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
- self, conn: AbstractConnection, cmd: UserSyncCommand
+ self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
- self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
- def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+ def on_FEDERATION_ACK(
+ self, conn: IReplicationConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
- self, conn: AbstractConnection, cmd: UserIpCommand
+ self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
- def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
- self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
- self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+ def on_REMOTE_SERVER_UP(
+ self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
- def new_connection(self, connection: AbstractConnection):
+ def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
- def lost_connection(self, connection: AbstractConnection):
+ def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
- self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..825900f64c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,8 +53,10 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
+from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -121,6 +122,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
+ # The transport is going to be an ITCPTransport, but that doesn't have the
+ # (un)registerProducer methods, those are only on the implementation.
+ transport = None # type: Connection
+
delimiter = b"\n"
# Valid commands we expect to receive
@@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
+ assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
+ assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
+ assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
+ assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
@@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections."""
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection"""
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7560706b4b..2f4d407f94 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IAddress, IConnector
+from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
- AbstractConnection,
+ IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
- (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
+ # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
+ # it's rubbish. We add our own here.
+
+ def startedConnecting(self, connector: IConnector):
+ logger.info(
+ "Connecting to redis server %s", format_address(connector.getDestination())
+ )
+ super().startedConnecting(connector)
+
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s failed: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s lost: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionLost(connector, reason)
+
+
+def format_address(address: IAddress) -> str:
+ if isinstance(address, (IPv4Address, IPv6Address)):
+ return "%s:%i" % (address.host, address.port)
+ return str(address)
+
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
@@ -328,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
+ reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler
|