diff --git a/changelog.d/7352.feature b/changelog.d/7352.feature
new file mode 100644
index 0000000000..ce6140fdd1
--- /dev/null
+++ b/changelog.d/7352.feature
@@ -0,0 +1 @@
+Add support for running replication over Redis when using workers.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 6132727cbd..88a5a97caf 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -220,12 +220,6 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
- def add_remote_server_up_callback(self, cb: Callable[[str], None]):
- """Add a callback that will be called when synapse detects a server
- has been
- """
- self.remote_server_up_callbacks.append(cb)
-
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@@ -544,6 +538,3 @@ class Notifier(object):
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
-
- for cb in self.remote_server_up_callbacks:
- cb(server)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 3a8c7c7e2d..b8f49a8d0f 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -117,7 +117,6 @@ class ReplicationCommandHandler:
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
- self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -163,7 +162,7 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
- async def on_REPLICATE(self, cmd: ReplicateCommand):
+ async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@@ -173,7 +172,7 @@ class ReplicationCommandHandler:
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
- async def on_USER_SYNC(self, cmd: UserSyncCommand):
+ async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@@ -181,17 +180,23 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+ async def on_CLEAR_USER_SYNC(
+ self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ ):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
- async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+ async def on_FEDERATION_ACK(
+ self, conn: AbstractConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
- async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+ async def on_REMOVE_PUSHER(
+ self, conn: AbstractConnection, cmd: RemovePusherCommand
+ ):
remove_pusher_counter.inc()
if self._is_master:
@@ -201,7 +206,9 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+ async def on_INVALIDATE_CACHE(
+ self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+ ):
invalidate_cache_counter.inc()
if self._is_master:
@@ -211,7 +218,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
- async def on_USER_IP(self, cmd: UserIpCommand):
+ async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@@ -227,7 +234,7 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, cmd: RdataCommand):
+ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -278,7 +285,7 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
- async def on_POSITION(self, cmd: PositionCommand):
+ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -332,12 +339,30 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ async def on_REMOTE_SERVER_UP(
+ self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
- if self._is_master:
- self._notifier.notify_remote_server_up(cmd.data)
+ self._notifier.notify_remote_server_up(cmd.data)
+
+ # We relay to all other connections to ensure every instance gets the
+ # notification.
+ #
+ # When configured to use redis we'll always only have one connection and
+ # so this is a no-op (all instances will have already received the same
+ # REMOTE_SERVER_UP command).
+ #
+ # For direct TCP connections this will relay to all other connections
+ # connected to us. When on master this will correctly fan out to all
+ # other direct TCP clients and on workers there'll only be the one
+ # connection to master.
+ #
+ # (The logic here should also be sound if we have a mix of Redis and
+ # direct TCP connections so long as there is only one traffic route
+ # between two instances, but that is not currently supported).
+ self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
@@ -382,11 +407,21 @@ class ReplicationCommandHandler:
"""
return bool(self._connections)
- def send_command(self, cmd: Command):
+ def send_command(
+ self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ ):
"""Send a command to all connected connections.
+
+ Args:
+ cmd
+ ignore_conn: If set don't send command to the given connection.
+ Used when relaying commands from one connection to all others.
"""
if self._connections:
for connection in self._connections:
+ if connection == ignore_conn:
+ continue
+
try:
connection.send_command(cmd)
except Exception:
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e3f64eba8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 49b3ed0c5e..617e860f95 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
new file mode 100644
index 0000000000..d1c15caeb0
--- /dev/null
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+from twisted.internet.interfaces import IProtocol
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests.unittest import HomeserverTestCase
+
+
+class RemoteServerUpTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.factory = ReplicationStreamProtocolFactory(hs)
+
+ def _make_client(self) -> Tuple[IProtocol, StringTransport]:
+ """Create a new direct TCP replication connection
+ """
+
+ proto = self.factory.buildProtocol(("127.0.0.1", 0))
+ transport = StringTransport()
+ proto.makeConnection(transport)
+
+ # We can safely ignore the commands received during connection.
+ self.pump()
+ transport.clear()
+
+ return proto, transport
+
+ def test_relay(self):
+ """Test that Synapse will relay REMOTE_SERVER_UP commands to all
+ other connections, but not the one that sent it.
+ """
+
+ proto1, transport1 = self._make_client()
+
+ # We shouldn't receive an echo.
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+
+ # But we should see an echo if we connect another client
+ proto2, transport2 = self._make_client()
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+ self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")
|