summary refs log tree commit diff
path: root/synapse/replication/tcp/redis.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/redis.py')
-rw-r--r--synapse/replication/tcp/redis.py173
1 files changed, 118 insertions, 55 deletions
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py

index bc6ba709a7..0e6155cf53 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -15,14 +15,16 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +import attr import txredisapi from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, run_as_background_process, + wrap_as_background_process, ) from synapse.replication.tcp.commands import ( Command, @@ -41,6 +43,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +T = TypeVar("T") +V = TypeVar("V") + + +@attr.s +class ConstantProperty(Generic[T, V]): + """A descriptor that returns the given constant, ignoring attempts to set + it. + """ + + constant = attr.ib() # type: V + + def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V: + return self.constant + + def __set__(self, obj: Optional[T], value: V): + pass + class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. @@ -59,16 +79,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): immediately after initialisation. Attributes: - handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to and publish from - (not anything to do with Synapse replication streams). - outbound_redis_connection: The connection to redis to use to send + synapse_handler: The command handler to handle incoming commands. + synapse_stream_name: The *redis* stream name to subscribe to and publish + from (not anything to do with Synapse replication streams). + synapse_outbound_redis_connection: The connection to redis to use to send commands. """ - handler = None # type: ReplicationCommandHandler - stream_name = None # type: str - outbound_redis_connection = None # type: txredisapi.RedisProtocol + synapse_handler = None # type: ReplicationCommandHandler + synapse_stream_name = None # type: str + synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -88,23 +108,22 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. - logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) - await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name) + await make_deferred_yieldable(self.subscribe(self.synapse_stream_name)) logger.info( "Successfully subscribed to redis stream, sending REPLICATE command" ) - self.handler.new_connection(self) + self.synapse_handler.new_connection(self) await self._async_send_command(ReplicateCommand()) logger.info("REPLICATE successfully sent") # We send out our positions when there is a new connection in case the # other side missed updates. We do this for Redis connections as the # otherside won't know we've connected and so won't issue a REPLICATE. - self.handler.send_positions_to_connection(self) + self.synapse_handler.send_positions_to_connection(self) def messageReceived(self, pattern: str, channel: str, message: str): - """Received a message from redis. - """ + """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) @@ -117,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): cmd = parse_command_from_line(message) except Exception: logger.exception( - "Failed to parse replication line: %r", message, + "Failed to parse replication line: %r", + message, ) return @@ -137,7 +157,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): cmd: received command """ - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None) if not cmd_func: logger.warning("Unhandled command: %r", cmd) return @@ -155,7 +175,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) - self.handler.lost_connection(self) + self.synapse_handler.lost_connection(self) # mark the logging context as finished self._logging_context.__exit__(None, None, None) @@ -183,11 +203,58 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() await make_deferred_yieldable( - self.outbound_redis_connection.publish(self.stream_name, encoded_string) + self.synapse_outbound_redis_connection.publish( + self.synapse_stream_name, encoded_string + ) ) -class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): +class SynapseRedisFactory(txredisapi.RedisFactory): + """A subclass of RedisFactory that periodically sends pings to ensure that + we detect dead connections. + """ + + # We want to *always* retry connecting, txredisapi will stop if there is a + # failure during certain operations, e.g. during AUTH. + continueTrying = cast(bool, ConstantProperty(True)) + + def __init__( + self, + hs: "HomeServer", + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = txredisapi.ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: int = 30, + convertNumbers: Optional[int] = True, + ): + super().__init__( + uuid=uuid, + dbid=dbid, + poolsize=poolsize, + isLazy=isLazy, + handler=handler, + charset=charset, + password=password, + replyTimeout=replyTimeout, + convertNumbers=convertNumbers, + ) + + hs.get_clock().looping_call(self._send_ping, 30 * 1000) + + @wrap_as_background_process("redis_ping") + async def _send_ping(self): + for connection in self.pool: + try: + await make_deferred_yieldable(connection.ping()) + except Exception: + logger.warning("Failed to send ping to a redis connection") + + +class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately subscribes to a stream. @@ -199,72 +266,68 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): """ maxDelay = 5 - continueTrying = True protocol = RedisSubscriber def __init__( self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol ): - super().__init__() - - # This sets the password on the RedisFactory base class (as - # SubscriberFactory constructor doesn't pass it through). - self.password = hs.config.redis.redis_password + super().__init__( + hs, + uuid="subscriber", + dbid=None, + poolsize=1, + replyTimeout=30, + password=hs.config.redis.redis_password, + ) - self.handler = hs.get_tcp_replication() - self.stream_name = hs.hostname + self.synapse_handler = hs.get_tcp_replication() + self.synapse_stream_name = hs.hostname - self.outbound_redis_connection = outbound_redis_connection + self.synapse_outbound_redis_connection = outbound_redis_connection def buildProtocol(self, addr): - p = super().buildProtocol(addr) # type: RedisSubscriber + p = super().buildProtocol(addr) + p = cast(RedisSubscriber, p) # We do this here rather than add to the constructor of `RedisSubcriber` # as to do so would involve overriding `buildProtocol` entirely, however # the base method does some other things than just instantiating the # protocol. - p.handler = self.handler - p.outbound_redis_connection = self.outbound_redis_connection - p.stream_name = self.stream_name - p.password = self.password + p.synapse_handler = self.synapse_handler + p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection + p.synapse_stream_name = self.synapse_stream_name return p def lazyConnection( - reactor, + hs: "HomeServer", host: str = "localhost", port: int = 6379, dbid: Optional[int] = None, reconnect: bool = True, - charset: str = "utf-8", password: Optional[str] = None, - connectTimeout: Optional[int] = None, - replyTimeout: Optional[int] = None, - convertNumbers: bool = True, + replyTimeout: int = 30, ) -> txredisapi.RedisProtocol: - """Equivalent to `txredisapi.lazyConnection`, except allows specifying a - reactor. + """Creates a connection to Redis that is lazily set up and reconnects if the + connections is lost. """ - isLazy = True - poolsize = 1 - uuid = "%s:%d" % (host, port) - factory = txredisapi.RedisFactory( - uuid, - dbid, - poolsize, - isLazy, - txredisapi.ConnectionHandler, - charset, - password, - replyTimeout, - convertNumbers, + factory = SynapseRedisFactory( + hs, + uuid=uuid, + dbid=dbid, + poolsize=1, + isLazy=True, + handler=txredisapi.ConnectionHandler, + password=password, + replyTimeout=replyTimeout, ) factory.continueTrying = reconnect - for x in range(poolsize): - reactor.connectTCP(host, port, factory, connectTimeout) + + reactor = hs.get_reactor() + reactor.connectTCP(host, port, factory, 30) return factory.handler