summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-04-29 14:10:59 +0100
committerGitHub <noreply@github.com>2020-04-29 14:10:59 +0100
commit3eab76ad43e1de4074550c468d11318d497ff632 (patch)
tree919ff2563327b2cef97a1c1625122fdc4fe7815f
parentFix limit logic for EventsStream (#7358) (diff)
downloadsynapse-3eab76ad43e1de4074550c468d11318d497ff632.tar.xz
Don't relay REMOTE_SERVER_UP cmds to same conn. (#7352)
For direct TCP connections we need the master to relay REMOTE_SERVER_UP
commands to the other connections so that all instances get notified
about it. The old implementation just relayed to all connections,
assuming that sending back to the original sender of the command was
safe. This is not true for redis, where commands sent get echoed back to
the sender, which was causing master to effectively infinite loop
sending and then re-receiving REMOTE_SERVER_UP commands that it sent.

The fix is to ensure that we only relay to *other* connections and not
to the connection we received the notification from.

Fixes #7334.
Diffstat (limited to '')
-rw-r--r--changelog.d/7352.feature1
-rw-r--r--synapse/notifier.py9
-rw-r--r--synapse/replication/tcp/handler.py63
-rw-r--r--synapse/replication/tcp/protocol.py2
-rw-r--r--synapse/replication/tcp/redis.py2
-rw-r--r--tests/replication/tcp/test_remote_server_up.py62
6 files changed, 114 insertions, 25 deletions
diff --git a/changelog.d/7352.feature b/changelog.d/7352.feature
new file mode 100644
index 0000000000..ce6140fdd1
--- /dev/null
+++ b/changelog.d/7352.feature
@@ -0,0 +1 @@
+Add support for running replication over Redis when using workers.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6132727cbd..88a5a97caf 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -220,12 +220,6 @@ class Notifier(object):
         """
         self.replication_callbacks.append(cb)
 
-    def add_remote_server_up_callback(self, cb: Callable[[str], None]):
-        """Add a callback that will be called when synapse detects a server
-        has been
-        """
-        self.remote_server_up_callbacks.append(cb)
-
     def on_new_room_event(
         self, event, room_stream_id, max_room_stream_id, extra_users=[]
     ):
@@ -544,6 +538,3 @@ class Notifier(object):
         # circular dependencies.
         if self.federation_sender:
             self.federation_sender.wake_destination(server)
-
-        for cb in self.remote_server_up_callbacks:
-            cb(server)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 3a8c7c7e2d..b8f49a8d0f 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -117,7 +117,6 @@ class ReplicationCommandHandler:
         self._server_notices_sender = None
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
-            self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote server
@@ -163,7 +162,7 @@ class ReplicationCommandHandler:
             port = hs.config.worker_replication_port
             hs.get_reactor().connectTCP(host, port, self._factory)
 
-    async def on_REPLICATE(self, cmd: ReplicateCommand):
+    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         # We only want to announce positions by the writer of the streams.
         # Currently this is just the master process.
         if not self._is_master:
@@ -173,7 +172,7 @@ class ReplicationCommandHandler:
             current_token = stream.current_token()
             self.send_command(PositionCommand(stream_name, current_token))
 
-    async def on_USER_SYNC(self, cmd: UserSyncCommand):
+    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
         user_sync_counter.inc()
 
         if self._is_master:
@@ -181,17 +180,23 @@ class ReplicationCommandHandler:
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
 
-    async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+    async def on_CLEAR_USER_SYNC(
+        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+    ):
         if self._is_master:
             await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
 
-    async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+    async def on_FEDERATION_ACK(
+        self, conn: AbstractConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.token)
 
-    async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+    async def on_REMOVE_PUSHER(
+        self, conn: AbstractConnection, cmd: RemovePusherCommand
+    ):
         remove_pusher_counter.inc()
 
         if self._is_master:
@@ -201,7 +206,9 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+    async def on_INVALIDATE_CACHE(
+        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+    ):
         invalidate_cache_counter.inc()
 
         if self._is_master:
@@ -211,7 +218,7 @@ class ReplicationCommandHandler:
                 cmd.cache_func, tuple(cmd.keys)
             )
 
-    async def on_USER_IP(self, cmd: UserIpCommand):
+    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
         if self._is_master:
@@ -227,7 +234,7 @@ class ReplicationCommandHandler:
         if self._server_notices_sender:
             await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, cmd: RdataCommand):
+    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
@@ -278,7 +285,7 @@ class ReplicationCommandHandler:
         logger.debug("Received rdata %s -> %s", stream_name, token)
         await self._replication_data_handler.on_rdata(stream_name, token, rows)
 
-    async def on_POSITION(self, cmd: PositionCommand):
+    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         stream = self._streams.get(cmd.stream_name)
         if not stream:
             logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -332,12 +339,30 @@ class ReplicationCommandHandler:
 
             self._streams_connected.add(cmd.stream_name)
 
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+    async def on_REMOTE_SERVER_UP(
+        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
-        if self._is_master:
-            self._notifier.notify_remote_server_up(cmd.data)
+        self._notifier.notify_remote_server_up(cmd.data)
+
+        # We relay to all other connections to ensure every instance gets the
+        # notification.
+        #
+        # When configured to use redis we'll always only have one connection and
+        # so this is a no-op (all instances will have already received the same
+        # REMOTE_SERVER_UP command).
+        #
+        # For direct TCP connections this will relay to all other connections
+        # connected to us. When on master this will correctly fan out to all
+        # other direct TCP clients and on workers there'll only be the one
+        # connection to master.
+        #
+        # (The logic here should also be sound if we have a mix of Redis and
+        # direct TCP connections so long as there is only one traffic route
+        # between two instances, but that is not currently supported).
+        self.send_command(cmd, ignore_conn=conn)
 
     def new_connection(self, connection: AbstractConnection):
         """Called when we have a new connection.
@@ -382,11 +407,21 @@ class ReplicationCommandHandler:
         """
         return bool(self._connections)
 
-    def send_command(self, cmd: Command):
+    def send_command(
+        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+    ):
         """Send a command to all connected connections.
+
+        Args:
+            cmd
+            ignore_conn: If set don't send command to the given connection.
+                Used when relaying commands from one connection to all others.
         """
         if self._connections:
             for connection in self._connections:
+                if connection == ignore_conn:
+                    continue
+
                 try:
                     connection.send_command(cmd)
                 except Exception:
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e3f64eba8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            await cmd_func(self, cmd)
             handled = True
 
         if not handled:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 49b3ed0c5e..617e860f95 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            await cmd_func(self, cmd)
             handled = True
 
         if not handled:
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
new file mode 100644
index 0000000000..d1c15caeb0
--- /dev/null
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+from twisted.internet.interfaces import IProtocol
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests.unittest import HomeserverTestCase
+
+
+class RemoteServerUpTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.factory = ReplicationStreamProtocolFactory(hs)
+
+    def _make_client(self) -> Tuple[IProtocol, StringTransport]:
+        """Create a new direct TCP replication connection
+        """
+
+        proto = self.factory.buildProtocol(("127.0.0.1", 0))
+        transport = StringTransport()
+        proto.makeConnection(transport)
+
+        # We can safely ignore the commands received during connection.
+        self.pump()
+        transport.clear()
+
+        return proto, transport
+
+    def test_relay(self):
+        """Test that Synapse will relay REMOTE_SERVER_UP commands to all
+        other connections, but not the one that sent it.
+        """
+
+        proto1, transport1 = self._make_client()
+
+        # We shouldn't receive an echo.
+        proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+        self.pump()
+        self.assertEqual(transport1.value(), b"")
+
+        # But we should see an echo if we connect another client
+        proto2, transport2 = self._make_client()
+        proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+
+        self.pump()
+        self.assertEqual(transport1.value(), b"")
+        self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")