diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..58d46a5951 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -300,7 +304,7 @@ class ReplicationCommandHandler:
# First create the connection for sending commands.
outbound_redis_connection = lazyConnection(
- reactor=hs.get_reactor(),
+ hs=hs,
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
+ wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
- handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to and publish from
- (not anything to do with Synapse replication streams).
- outbound_redis_connection: The connection to redis to use to send
+ synapse_handler: The command handler to handle incoming commands.
+ synapse_stream_name: The *redis* stream name to subscribe to and publish
+ from (not anything to do with Synapse replication streams).
+ synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
- handler = None # type: ReplicationCommandHandler
- stream_name = None # type: str
- outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler = None # type: ReplicationCommandHandler
+ synapse_stream_name = None # type: str
+ synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
- await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+ await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
- self.handler.new_connection(self)
+ self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
- self.handler.send_positions_to_connection(self)
+ self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
- self.handler.lost_connection(self)
+ self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
- self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ self.synapse_outbound_redis_connection.publish(
+ self.synapse_stream_name, encoded_string
+ )
+ )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+ """A subclass of RedisFactory that periodically sends pings to ensure that
+ we detect dead connections.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = txredisapi.ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: int = 30,
+ convertNumbers: Optional[int] = True,
+ ):
+ super().__init__(
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=poolsize,
+ isLazy=isLazy,
+ handler=handler,
+ charset=charset,
+ password=password,
+ replyTimeout=replyTimeout,
+ convertNumbers=convertNumbers,
)
+ hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+ @wrap_as_background_process("redis_ping")
+ async def _send_ping(self):
+ for connection in self.pool:
+ try:
+ await make_deferred_yieldable(connection.ping())
+ except Exception:
+ logger.warning("Failed to send ping to a redis connection")
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
- super().__init__()
-
- # This sets the password on the RedisFactory base class (as
- # SubscriberFactory constructor doesn't pass it through).
- self.password = hs.config.redis.redis_password
+ super().__init__(
+ hs,
+ uuid="subscriber",
+ dbid=None,
+ poolsize=1,
+ replyTimeout=30,
+ password=hs.config.redis.redis_password,
+ )
- self.handler = hs.get_tcp_replication()
- self.stream_name = hs.hostname
+ self.synapse_handler = hs.get_tcp_replication()
+ self.synapse_stream_name = hs.hostname
- self.outbound_redis_connection = outbound_redis_connection
+ self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
- p = super().buildProtocol(addr) # type: RedisSubscriber
+ p = super().buildProtocol(addr)
+ p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubcriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
- p.handler = self.handler
- p.outbound_redis_connection = self.outbound_redis_connection
- p.stream_name = self.stream_name
- p.password = self.password
+ p.synapse_handler = self.synapse_handler
+ p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+ p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
- reactor,
+ hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
- charset: str = "utf-8",
password: Optional[str] = None,
- connectTimeout: Optional[int] = None,
- replyTimeout: Optional[int] = None,
- convertNumbers: bool = True,
+ replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
- """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
- reactor.
+ """Creates a connection to Redis that is lazily set up and reconnects if the
+ connection is lost.
"""
- isLazy = True
- poolsize = 1
-
uuid = "%s:%d" % (host, port)
- factory = txredisapi.RedisFactory(
- uuid,
- dbid,
- poolsize,
- isLazy,
- txredisapi.ConnectionHandler,
- charset,
- password,
- replyTimeout,
- convertNumbers,
+ factory = SynapseRedisFactory(
+ hs,
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=1,
+ isLazy=True,
+ handler=txredisapi.ConnectionHandler,
+ password=password,
+ replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
- for x in range(poolsize):
- reactor.connectTCP(host, port, factory, connectTimeout)
+
+ reactor = hs.get_reactor()
+ reactor.connectTCP(host, port, factory, 30)
return factory.handler
|