summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
author Richard van der Hoff <richard@matrix.org> 2021-03-16 12:42:54 +0000
committer Richard van der Hoff <richard@matrix.org> 2021-03-16 12:42:54 +0000
commit d8953b34f2b593043b518304fa70099432955f81 (patch)
tree fdc122f5c1e368ea4716c6c9fe66b3d76c55c2cd /synapse/replication
parent Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parent Pass SSO IdP information to spam checker's registration function (#9626) (diff)
download synapse-d8953b34f2b593043b518304fa70099432955f81.tar.xz
Merge branch 'develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/replication')
-rw-r--r-- synapse/replication/http/login.py 4
-rw-r--r-- synapse/replication/tcp/handler.py 48
-rw-r--r-- synapse/replication/tcp/protocol.py 33
-rw-r--r-- synapse/replication/tcp/redis.py 45
4 files changed, 87 insertions, 43 deletions
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py

index 36071feb36..4ec1bfa6ea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] - device_id, access_token = await self.registration_handler.register_device( + res = await self.registration_handler.register_device_inner( user_id, device_id, initial_display_name, @@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_appservice_ghost=is_appservice_ghost, ) - return 200, {"device_id": device_id, "access_token": access_token} + return 200, res def register_servlets(hs, http_server): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..a8894beadf 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import ( UserIpCommand, UserSyncCommand, ) -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import IReplicationConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, AccountDataStream, @@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache" # the type of the entries in _command_queues_by_stream _StreamCommandQueue = Deque[ - Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] + Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection] ] @@ -174,7 +174,7 @@ class ReplicationCommandHandler: # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) - self._connections = [] # type: List[AbstractConnection] + self._connections = [] # type: List[IReplicationConnection] LaterGauge( "synapse_replication_tcp_resource_total_connections", @@ -197,7 +197,7 @@ class ReplicationCommandHandler: # For each connection, the incoming stream names that have received a POSITION # from that connection. 
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] LaterGauge( "synapse_replication_tcp_command_queue", @@ -220,7 +220,7 @@ class ReplicationCommandHandler: self._server_notices_sender = hs.get_server_notices_sender() def _add_command_to_stream_queue( - self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: """Queue the given received command for processing @@ -267,7 +267,7 @@ class ReplicationCommandHandler: async def _process_command( self, cmd: Union[PositionCommand, RdataCommand], - conn: AbstractConnection, + conn: IReplicationConnection, stream_name: str, ) -> None: if isinstance(cmd, PositionCommand): @@ -302,7 +302,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, + hs.config.redis.redis_host.encode(), hs.config.redis.redis_port, self._factory, ) @@ -311,7 +311,7 @@ class ReplicationCommandHandler: self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" @@ -321,10 +321,10 @@ class ReplicationCommandHandler: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: AbstractConnection): + def send_positions_to_connection(self, conn: IReplicationConnection): """Send current position of 
all streams this process is source of to the connection. """ @@ -347,7 +347,7 @@ class ReplicationCommandHandler: ) def on_USER_SYNC( - self, conn: AbstractConnection, cmd: UserSyncCommand + self, conn: IReplicationConnection, cmd: UserSyncCommand ) -> Optional[Awaitable[None]]: user_sync_counter.inc() @@ -359,21 +359,23 @@ class ReplicationCommandHandler: return None def on_CLEAR_USER_SYNC( - self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: if self._is_master: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None - def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): + def on_FEDERATION_ACK( + self, conn: IReplicationConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) def on_USER_IP( - self, conn: AbstractConnection, cmd: UserIpCommand + self, conn: IReplicationConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() @@ -395,7 +397,7 @@ class ReplicationCommandHandler: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -412,7 +414,7 @@ class ReplicationCommandHandler: self._add_command_to_stream_queue(conn, cmd) async def _process_rdata( - self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand ) -> None: """Process an RDATA command @@ -486,7 +488,7 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: AbstractConnection, 
cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -496,7 +498,7 @@ class ReplicationCommandHandler: self._add_command_to_stream_queue(conn, cmd) async def _process_position( - self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand ) -> None: """Process a POSITION command @@ -553,7 +555,9 @@ class ReplicationCommandHandler: self._streams_by_connection.setdefault(conn, set()).add(stream_name) - def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): + def on_REMOTE_SERVER_UP( + self, conn: IReplicationConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -576,7 +580,7 @@ class ReplicationCommandHandler: # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: AbstractConnection): + def new_connection(self, connection: IReplicationConnection): """Called when we have a new connection.""" self._connections.append(connection) @@ -603,7 +607,7 @@ class ReplicationCommandHandler: UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: AbstractConnection): + def lost_connection(self, connection: IReplicationConnection): """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. 
streams = self._streams_by_connection.pop(connection, None) @@ -624,7 +628,7 @@ class ReplicationCommandHandler: return bool(self._connections) def send_command( - self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None ): """Send a command to all connected connections. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..825900f64c 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct @@ -54,8 +53,10 @@ from inspect import isawaitable from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from zope.interface import Interface, implementer from twisted.internet import task +from twisted.internet.tcp import Connection from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -121,6 +122,14 @@ class ConnectionStates: CLOSED = "closed" +class IReplicationConnection(Interface): + """An interface for replication connections.""" + + def send_command(cmd: Command): + """Send the command down the connection""" + + +@implementer(IReplicationConnection) class BaseReplicationStreamProtocol(LineOnlyReceiver): """Base replication protocol shared between client and server. @@ -137,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): (if they send a `PING` command) """ + # The transport is going to be an ITCPTransport, but that doesn't have the + # (un)registerProducer methods, those are only on the implementation. 
+ transport = None # type: Connection + delimiter = b"\n" # Valid commands we expect to receive @@ -181,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): connected_connections.append(self) # Register connection for metrics + assert self.transport is not None self.transport.registerProducer(self, True) # For the *Producing callbacks self._send_pending_commands() @@ -205,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) + assert self.transport is not None self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: @@ -294,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def close(self): logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() + assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() @@ -391,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): + assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: connection_close_counter.labels(reason.__class__.__name__).inc() @@ -495,20 +512,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) -class AbstractConnection(abc.ABC): - """An interface for replication connections.""" - - @abc.abstractmethod - def send_command(self, cmd: Command): - """Send the command down the connection""" - pass - - -# This tells python that `BaseReplicationStreamProtocol` implements the -# interface. -AbstractConnection.register(BaseReplicationStreamProtocol) - - # The following simply registers metrics for the replication connections pending_commands = LaterGauge( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7560706b4b..2f4d407f94 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast import attr import txredisapi +from zope.interface import implementer + +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.interfaces import IAddress, IConnector +from twisted.python.failure import Failure from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import ( @@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import ( parse_command_from_line, ) from synapse.replication.tcp.protocol import ( - AbstractConnection, + IReplicationConnection, tcp_inbound_commands_counter, tcp_outbound_commands_counter, ) @@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]): pass -class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): +@implementer(IReplicationConnection) +class RedisSubscriber(txredisapi.SubscriberProtocol): """Connection to redis subscribed to replication stream. This class fulfils two functions: @@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): connection, parsing *incoming* messages into replication commands, and passing them to `ReplicationCommandHandler` - (b) it implements the AbstractConnection API, where it sends *outgoing* commands + (b) it implements the IReplicationConnection API, where it sends *outgoing* commands onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom @@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory): except Exception: logger.warning("Failed to send ping to a redis connection") + # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but + # it's rubbish. We add our own here. 
+ + def startedConnecting(self, connector: IConnector): + logger.info( + "Connecting to redis server %s", format_address(connector.getDestination()) + ) + super().startedConnecting(connector) + + def clientConnectionFailed(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s failed: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector: IConnector, reason: Failure): + logger.info( + "Connection to redis server %s lost: %s", + format_address(connector.getDestination()), + reason.value, + ) + super().clientConnectionLost(connector, reason) + + +def format_address(address: IAddress) -> str: + if isinstance(address, (IPv4Address, IPv6Address)): + return "%s:%i" % (address.host, address.port) + return str(address) + class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately @@ -328,6 +365,6 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) + reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) return factory.handler