summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/handler.py8
-rw-r--r--synapse/replication/tcp/redis.py143
2 files changed, 98 insertions, 53 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..58d46a5951 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -300,7 +304,7 @@ class ReplicationCommandHandler:
 
             # First create the connection for sending commands.
             outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
+                hs=hs,
                 host=hs.config.redis_host,
                 port=hs.config.redis_port,
                 password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
 
 import txredisapi
 
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
+        )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
         )
 
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connections is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
     )
     factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler