summary refs log tree commit diff
path: root/synapse/replication/tcp/redis.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/redis.py')
-rw-r--r--synapse/replication/tcp/redis.py62
1 files changed, 54 insertions, 8 deletions
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index c8f4bf8b27..7e96145b3b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -17,7 +17,12 @@ from inspect import isawaitable
 from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
 
 import attr
-import txredisapi
+from txredisapi import (
+    ConnectionHandler,
+    RedisFactory,
+    SubscriberProtocol,
+    UnixConnectionHandler,
+)
 from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
@@ -68,7 +73,7 @@ class ConstantProperty(Generic[T, V]):
 
 
 @implementer(IReplicationConnection)
-class RedisSubscriber(txredisapi.SubscriberProtocol):
+class RedisSubscriber(SubscriberProtocol):
     """Connection to redis subscribed to replication stream.
 
     This class fulfils two functions:
@@ -95,7 +100,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
     synapse_handler: "ReplicationCommandHandler"
     synapse_stream_prefix: str
     synapse_channel_names: List[str]
-    synapse_outbound_redis_connection: txredisapi.ConnectionHandler
+    synapse_outbound_redis_connection: ConnectionHandler
 
     def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
@@ -229,7 +234,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         )
 
 
-class SynapseRedisFactory(txredisapi.RedisFactory):
+class SynapseRedisFactory(RedisFactory):
     """A subclass of RedisFactory that periodically sends pings to ensure that
     we detect dead connections.
     """
@@ -245,7 +250,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
         dbid: Optional[int],
         poolsize: int,
         isLazy: bool = False,
-        handler: Type = txredisapi.ConnectionHandler,
+        handler: Type = ConnectionHandler,
         charset: str = "utf-8",
         password: Optional[str] = None,
         replyTimeout: int = 30,
@@ -326,7 +331,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     def __init__(
         self,
         hs: "HomeServer",
-        outbound_redis_connection: txredisapi.ConnectionHandler,
+        outbound_redis_connection: ConnectionHandler,
         channel_names: List[str],
     ):
         super().__init__(
@@ -368,7 +373,7 @@ def lazyConnection(
     reconnect: bool = True,
     password: Optional[str] = None,
     replyTimeout: int = 30,
-) -> txredisapi.ConnectionHandler:
+) -> ConnectionHandler:
     """Creates a connection to Redis that is lazily set up and reconnects if the
     connections is lost.
     """
@@ -380,7 +385,7 @@ def lazyConnection(
         dbid=dbid,
         poolsize=1,
         isLazy=True,
-        handler=txredisapi.ConnectionHandler,
+        handler=ConnectionHandler,
         password=password,
         replyTimeout=replyTimeout,
     )
@@ -408,3 +413,44 @@ def lazyConnection(
         )
 
     return factory.handler
+
+
+def lazyUnixConnection(
+    hs: "HomeServer",
+    path: str = "/tmp/redis.sock",
+    dbid: Optional[int] = None,
+    reconnect: bool = True,
+    password: Optional[str] = None,
+    replyTimeout: int = 30,
+) -> ConnectionHandler:
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connection is lost.
+
+    Returns:
+        A subclass of ConnectionHandler, which is a UnixConnectionHandler in this case.
+    """
+
+    uuid = path
+
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=UnixConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
+    )
+    factory.continueTrying = reconnect
+
+    reactor = hs.get_reactor()
+
+    reactor.connectUNIX(
+        path,
+        factory,
+        timeout=30,
+        checkPID=False,
+    )
+
+    return factory.handler