summary refs log tree commit diff
path: root/synapse/replication/tcp/handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/handler.py')
-rw-r--r--synapse/replication/tcp/handler.py51
1 files changed, 20 insertions, 31 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..d1d00c3717 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -299,20 +296,16 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
                 hs, outbound_redis_connection
             )
             hs.get_reactor().connectTCP(
-                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+                hs.config.redis.redis_host,
+                hs.config.redis.redis_port,
+                self._factory,
             )
         else:
             client_name = hs.get_instance_name()
@@ -322,13 +315,11 @@ class ReplicationCommandHandler:
             hs.get_reactor().connectTCP(host, port, self._factory)
 
     def get_streams(self) -> Dict[str, Stream]:
-        """Get a map from stream name to all streams.
-        """
+        """Get a map from stream name to all streams."""
         return self._streams
 
     def get_streams_to_replicate(self) -> List[Stream]:
-        """Get a list of streams that this instances replicates.
-        """
+        """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
     def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
@@ -349,7 +340,10 @@ class ReplicationCommandHandler:
             current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(
-                    stream.NAME, self._instance_name, current_token, current_token,
+                    stream.NAME,
+                    self._instance_name,
+                    current_token,
+                    current_token,
                 )
             )
 
@@ -601,8 +595,7 @@ class ReplicationCommandHandler:
         self.send_command(cmd, ignore_conn=conn)
 
     def new_connection(self, connection: AbstractConnection):
-        """Called when we have a new connection.
-        """
+        """Called when we have a new connection."""
         self._connections.append(connection)
 
         # If we are connected to replication as a client (rather than a server)
@@ -629,8 +622,7 @@ class ReplicationCommandHandler:
             )
 
     def lost_connection(self, connection: AbstractConnection):
-        """Called when a connection is closed/lost.
-        """
+        """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
         if streams:
@@ -687,15 +679,13 @@ class ReplicationCommandHandler:
     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
     ):
-        """Poke the master that a user has started/stopped syncing.
-        """
+        """Poke the master that a user has started/stopped syncing."""
         self.send_command(
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
         )
 
     def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        """Poke the master to remove a pusher for a user
-        """
+        """Poke the master to remove a pusher for a user"""
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
@@ -708,8 +698,7 @@ class ReplicationCommandHandler:
         device_id: str,
         last_seen: int,
     ):
-        """Tell the master that the user made a request.
-        """
+        """Tell the master that the user made a request."""
         cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
         self.send_command(cmd)