diff --git a/changelog.d/12672.misc b/changelog.d/12672.misc
new file mode 100644
index 0000000000..265e0a801f
--- /dev/null
+++ b/changelog.d/12672.misc
@@ -0,0 +1 @@
+Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
\ No newline at end of file
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 9aba1cd451..e1cbfa50eb 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -1,5 +1,5 @@
# Copyright 2017 Vector Creations Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -101,6 +101,9 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
+ # Additional Redis channel suffixes to subscribe to.
+ self._channels_to_subscribe_to: List[str] = []
+
self._is_presence_writer = (
hs.get_instance_name() in hs.config.worker.writers.presence
)
@@ -243,6 +246,31 @@ class ReplicationCommandHandler:
# If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master"
+ if self._is_master or self._should_insert_client_ips:
+ self.subscribe_to_channel("USER_IP")
+
+ def subscribe_to_channel(self, channel_name: str) -> None:
+ """
+ Indicates that we wish to subscribe to a Redis channel by name.
+
+ (The name will later be prefixed with the server name; i.e. subscribing
+ to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
+
+ Raises:
+ - If replication has already started, then it's too late to subscribe
+ to new channels.
+ """
+
+ if self._factory is not None:
+ # We don't allow subscribing after the fact to avoid the chance
+ # of missing an important message because we didn't subscribe in time.
+ raise RuntimeError(
+ "Cannot subscribe to more channels after replication started."
+ )
+
+ if channel_name not in self._channels_to_subscribe_to:
+ self._channels_to_subscribe_to.append(channel_name)
+
def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
@@ -321,7 +349,9 @@ class ReplicationCommandHandler:
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
- hs, outbound_redis_connection
+ hs,
+ outbound_redis_connection,
+ channel_names=self._channels_to_subscribe_to,
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 989c5be032..73294654ef 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,7 +14,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
import attr
import txredisapi
@@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
Attributes:
synapse_handler: The command handler to handle incoming commands.
- synapse_stream_name: The *redis* stream name to subscribe to and publish
+ synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
synapse_handler: "ReplicationCommandHandler"
- synapse_stream_name: str
+ synapse_stream_prefix: str
+ synapse_channel_names: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
def __init__(self, *args: Any, **kwargs: Any):
@@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
- await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
+ fully_qualified_stream_names = [
+ f"{self.synapse_stream_prefix}/{stream_suffix}"
+ for stream_suffix in self.synapse_channel_names
+ ] + [self.synapse_stream_prefix]
+ logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
+ await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
+
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
@@ -217,7 +223,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish(
- self.synapse_stream_name, encoded_string
+ self.synapse_stream_prefix, encoded_string
)
)
@@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
- subscribes to a stream.
+ subscribes to some streams.
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate to the redis connection
used to subscribe).
+ channel_names: A list of channel names to append to the base channel name
+ to additionally subscribe to.
+ e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
+ example.com; example.com/ABC; and example.com/DEF.
"""
maxDelay = 5
protocol = RedisSubscriber
def __init__(
- self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
+ self,
+ hs: "HomeServer",
+ outbound_redis_connection: txredisapi.ConnectionHandler,
+ channel_names: List[str],
):
super().__init__(
@@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
)
self.synapse_handler = hs.get_replication_command_handler()
- self.synapse_stream_name = hs.hostname
+ self.synapse_stream_prefix = hs.hostname
+ self.synapse_channel_names = channel_names
self.synapse_outbound_redis_connection = outbound_redis_connection
@@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
# protocol.
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
- p.synapse_stream_name = self.synapse_stream_name
+ p.synapse_stream_prefix = self.synapse_stream_prefix
+ p.synapse_channel_names = self.synapse_channel_names
return p
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a7602b4c96..970d5e533b 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
@@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest
from tests.server import FakeTransport
+from tests.utils import USE_POSTGRES_FOR_TESTS
try:
import hiredis
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
- self._subscribers = set()
+ self._subscribers_by_channel: Dict[
+ bytes, Set["FakeRedisPubSubProtocol"]
+ ] = defaultdict(set)
- def add_subscriber(self, conn):
+ def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
- self._subscribers.add(conn)
+ self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE"""
- self._subscribers.discard(conn)
+ """A connection has lost connection"""
+ for subscribers in self._subscribers_by_channel.values():
+ subscribers.discard(conn)
- def publish(self, conn, channel, msg) -> int:
+ def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers."""
- for sub in self._subscribers:
+ for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
- return len(self._subscribers)
+ return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
@@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
- (channel,) = args
- self._server.add_subscriber(self)
- self.send(["subscribe", channel, 1])
+ for idx, channel in enumerate(args):
+ num_channels = idx + 1
+ self._server.add_subscriber(self, channel)
+ self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
@@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
+
+
+class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
+ """
+ A test case that enables Redis, providing a fake Redis server.
+ """
+
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
new file mode 100644
index 0000000000..e6a19eafd5
--- /dev/null
+++ b/tests/replication/tcp/test_handler.py
@@ -0,0 +1,73 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+
+
+class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+ def test_subscribed_to_enough_redis_channels(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ self.assertCountEqual(
+ self.hs.get_replication_command_handler()._channels_to_subscribe_to,
+ ["USER_IP"],
+ )
+
+ def test_background_worker_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertIn(
+ "USER_IP",
+ worker1.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The counts are 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
+ )
+
+ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
+ # The default main process is subscribed to the USER_IP channel.
+ worker2 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker2",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+ self.assertNotIn(
+ "USER_IP",
+ worker2.get_replication_command_handler()._channels_to_subscribe_to,
+ )
+
+ # Advance so the Redis subscription gets processed
+ self.pump(0.1)
+
+ # The count is 2 because both the main process and the worker are subscribed.
+ self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
+ # For USER_IP, the count is 1 because only the main process is subscribed.
+ self.assertEqual(
+ len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
+ )
|