summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py2
-rw-r--r--synapse/replication/tcp/commands.py18
-rw-r--r--synapse/replication/tcp/handler.py50
-rw-r--r--synapse/replication/tcp/protocol.py38
-rw-r--r--synapse/replication/tcp/redis.py181
5 files changed, 255 insertions, 34 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 700ae79158..2d07b8b2d0 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class ReplicationClientFactory(ReconnectingClientFactory):
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 5ec89d0fb8..5002efe6a0 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -454,3 +454,21 @@ VALID_CLIENT_COMMANDS = (
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
 )
+
+
+def parse_command_from_line(line: str) -> Command:
+    """Parses a command from a received line.
+
+    Line should already be stripped of whitespace and be checked if blank.
+    """
+
+    idx = line.index(" ")
+    if idx >= 0:
+        cmd_name = line[:idx]
+        rest_of_line = line[idx + 1 :]
+    else:
+        cmd_name = line
+        rest_of_line = ""
+
+    cmd_cls = COMMAND_MAP[cmd_name]
+    return cmd_cls.from_line(rest_of_line)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e32e68e8c4..5b5ee2c13e 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -30,8 +30,10 @@ from typing import (
 
 from prometheus_client import Counter
 
+from twisted.internet.protocol import ReconnectingClientFactory
+
 from synapse.metrics import LaterGauge
-from synapse.replication.tcp.client import ReplicationClientFactory
+from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
@@ -92,7 +94,7 @@ class ReplicationCommandHandler:
         self._pending_batches = {}  # type: Dict[str, List[Any]]
 
         # The factory used to create connections.
-        self._factory = None  # type: Optional[ReplicationClientFactory]
+        self._factory = None  # type: Optional[ReconnectingClientFactory]
 
         # The currently connected connections.
         self._connections = []  # type: List[AbstractConnection]
@@ -119,11 +121,45 @@ class ReplicationCommandHandler:
         """Helper method to start a replication connection to the remote server
         using TCP.
         """
-        client_name = hs.config.worker_name
-        self._factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self._factory)
+        if hs.config.redis.redis_enabled:
+            from synapse.replication.tcp.redis import (
+                RedisDirectTcpReplicationClientFactory,
+            )
+            import txredisapi
+
+            logger.info(
+                "Connecting to redis (host=%r port=%r DBID=%r)",
+                hs.config.redis_host,
+                hs.config.redis_port,
+                hs.config.redis_dbid,
+            )
+
+            # We need two connections to redis, one for the subscription stream and
+            # one to send commands to (as you can't send further redis commands to a
+            # connection after SUBSCRIBE is called).
+
+            # First create the connection for sending commands.
+            outbound_redis_connection = txredisapi.lazyConnection(
+                host=hs.config.redis_host,
+                port=hs.config.redis_port,
+                dbid=hs.config.redis_dbid,
+                password=hs.config.redis.redis_password,
+                reconnect=True,
+            )
+
+            # Now create the factory/connection for the subscription stream.
+            self._factory = RedisDirectTcpReplicationClientFactory(
+                hs, outbound_redis_connection
+            )
+            hs.get_reactor().connectTCP(
+                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+            )
+        else:
+            client_name = hs.config.worker_name
+            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
+            host = hs.config.worker_replication_host
+            port = hs.config.worker_replication_port
+            hs.get_reactor().connectTCP(host, port, self._factory)
 
     async def on_REPLICATE(self, cmd: ReplicateCommand):
         # We only want to announce positions by the writer of the streams.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 9276ed2965..7240acb0a2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -63,7 +63,6 @@ from twisted.python.failure import Failure
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
-    COMMAND_MAP,
     VALID_CLIENT_COMMANDS,
     VALID_SERVER_COMMANDS,
     Command,
@@ -72,6 +71,7 @@ from synapse.replication.tcp.commands import (
     PingCommand,
     ReplicateCommand,
     ServerCommand,
+    parse_command_from_line,
 )
 from synapse.types import Collection
 from synapse.util import Clock
@@ -210,38 +210,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         linestr = line.decode("utf-8")
 
-        # split at the first " ", handling one-word commands
-        idx = linestr.index(" ")
-        if idx >= 0:
-            cmd_name = linestr[:idx]
-            rest_of_line = linestr[idx + 1 :]
-        else:
-            cmd_name = linestr
-            rest_of_line = ""
+        try:
+            cmd = parse_command_from_line(linestr)
+        except Exception as e:
+            logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
+            self.send_error("failed to parse line: %r (%r):" % (e, linestr))
+            return
 
-        if cmd_name not in self.VALID_INBOUND_COMMANDS:
-            logger.error("[%s] invalid command %s", self.id(), cmd_name)
-            self.send_error("invalid command: %s", cmd_name)
+        if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
+            logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
+            self.send_error("invalid command: %s", cmd.NAME)
             return
 
         self.last_received_command = self.clock.time_msec()
 
-        self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1
+        self.inbound_commands_counter[cmd.NAME] = (
+            self.inbound_commands_counter[cmd.NAME] + 1
         )
 
-        cmd_cls = COMMAND_MAP[cmd_name]
-        try:
-            cmd = cmd_cls.from_line(rest_of_line)
-        except Exception as e:
-            logger.exception(
-                "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
-            )
-            self.send_error(
-                "failed to parse line for  %r: %r (%r):" % (cmd_name, e, rest_of_line)
-            )
-            return
-
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
             "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
new file mode 100644
index 0000000000..4c08425735
--- /dev/null
+++ b/synapse/replication/tcp/redis.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+import txredisapi
+
+from synapse.logging.context import PreserveLoggingContext
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import (
+    Command,
+    ReplicateCommand,
+    parse_command_from_line,
+)
+from synapse.replication.tcp.protocol import AbstractConnection
+
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+    """Connection to redis subscribed to replication stream.
+
+    Parses incoming messages from redis into replication commands, and passes
+    them to `ReplicationCommandHandler`
+
+    Due to the vagaries of `txredisapi` we don't want to have a custom
+    constructor, so instead we expect the defined attributes below to be set
+    immediately after initialisation.
+
+    Attributes:
+        handler: The command handler to handle incoming commands.
+        stream_name: The *redis* stream name to subscribe to (not anything to
+            do with Synapse replication streams).
+        outbound_redis_connection: The connection to redis to use to send
+            commands.
+    """
+
+    handler = None  # type: ReplicationCommandHandler
+    stream_name = None  # type: str
+    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+
+    def connectionMade(self):
+        logger.info("Connected to redis instance")
+        self.subscribe(self.stream_name)
+        self.send_command(ReplicateCommand())
+
+        self.handler.new_connection(self)
+
+    def messageReceived(self, pattern: str, channel: str, message: str):
+        """Received a message from redis.
+        """
+
+        if message.strip() == "":
+            # Ignore blank lines
+            return
+
+        try:
+            cmd = parse_command_from_line(message)
+        except Exception:
+            logger.exception(
+                "[%s] failed to parse line: %r", message,
+            )
+            return
+
+        # Now lets try and call on_<CMD_NAME> function
+        run_as_background_process(
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
+        )
+
+    async def handle_command(self, cmd: Command):
+        """Handle a command we have received over the replication stream.
+
+        By default delegates to on_<COMMAND>, which should return an awaitable.
+
+        Args:
+            cmd: received command
+        """
+        handled = False
+
+        # First call any command handlers on this instance. These are for redis
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
+
+    def connectionLost(self, reason):
+        logger.info("Lost connection to redis instance")
+        self.handler.lost_connection(self)
+
+    def send_command(self, cmd: Command):
+        """Send a command if connection has been established.
+
+        Args:
+            cmd (Command)
+        """
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
+        if "\n" in string:
+            raise Exception("Unexpected newline in command: %r", string)
+
+        encoded_string = string.encode("utf-8")
+
+        async def _send():
+            with PreserveLoggingContext():
+                # Note that we use the other connection as we can't send
+                # commands using the subscription connection.
+                await self.outbound_redis_connection.publish(
+                    self.stream_name, encoded_string
+                )
+
+        run_as_background_process("send-cmd", _send)
+
+
+class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+    """This is a reconnecting factory that connects to redis and immediately
+    subscribes to a stream.
+
+    Args:
+        hs
+        outbound_redis_connection: A connection to redis that will be used to
+            send outbound commands (this is seperate to the redis connection
+            used to subscribe).
+    """
+
+    maxDelay = 5
+    continueTrying = True
+    protocol = RedisSubscriber
+
+    def __init__(
+        self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
+    ):
+
+        super().__init__()
+
+        # This sets the password on the RedisFactory base class (as
+        # SubscriberFactory constructor doesn't pass it through).
+        self.password = hs.config.redis.redis_password
+
+        self.handler = hs.get_tcp_replication()
+        self.stream_name = hs.hostname
+
+        self.outbound_redis_connection = outbound_redis_connection
+
+    def buildProtocol(self, addr):
+        p = super().buildProtocol(addr)  # type: RedisSubscriber
+
+        # We do this here rather than add to the constructor of `RedisSubcriber`
+        # as to do so would involve overriding `buildProtocol` entirely, however
+        # the base method does some other things than just instantiating the
+        # protocol.
+        p.handler = self.handler
+        p.outbound_redis_connection = self.outbound_redis_connection
+        p.stream_name = self.stream_name
+
+        return p