diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e6a50aa74e..acfa66a7a8 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -36,7 +36,12 @@ from synapse.replication.tcp.commands import (
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ CachesStream,
+ FederationStream,
+ Stream,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -73,6 +78,26 @@ class ReplicationCommandHandler:
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
+ # List of streams that this instance is the source of
+ self._streams_to_replicate = [] # type: List[Stream]
+
+ for stream in self._streams.values():
+ if stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream.
+ self._streams_to_replicate.append(stream)
+ continue
+
+ # Only add any other streams if we're on master.
+ if hs.config.worker_app is not None:
+ continue
+
+ if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
+
+ self._streams_to_replicate.append(stream)
+
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
@@ -150,6 +175,16 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a map from stream name to all streams.
+ """
+ return self._streams
+
+ def get_streams_to_replicate(self) -> List[Stream]:
+ """Get a list of streams that this instances replicates.
+ """
+ return self._streams_to_replicate
+
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
@@ -158,15 +193,15 @@ class ReplicationCommandHandler:
the connection.
"""
- # We only want to announce positions by the writer of the streams.
- # Currently this is just the master process.
- if not self._is_master:
- return
-
- for stream_name, stream in self._streams.items():
- current_token = stream.current_token(self._instance_name)
- conn.send_command(
- PositionCommand(stream_name, self._instance_name, current_token)
+ # We respond with current position of all streams this instance
+ # replicates.
+ for stream in self.get_streams_to_replicate():
+ self.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ stream.current_token(self._instance_name),
+ )
)
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 002171ce7c..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
import logging
import random
-from typing import Dict, List
from prometheus_client import Counter
@@ -25,12 +24,6 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import (
- STREAMS_MAP,
- CachesStream,
- FederationStream,
- Stream,
-)
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -80,31 +73,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
- # Work out list of streams that this instance is the source of.
- self.streams = [] # type: List[Stream]
-
- # All workers can write to the cache invalidation stream.
- self.streams.append(CachesStream(hs))
-
- if hs.config.worker_app is None:
- for stream in STREAMS_MAP.values():
- if stream == FederationStream and hs.config.send_federation:
- # We only support federation stream if federation sending
- # has been disabled on the master.
- continue
-
- if stream == CachesStream:
- # We've already added it above.
- continue
-
- self.streams.append(stream(hs))
-
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
- # Only bother registering the notifier callback if we have streams to
- # publish.
- if self.streams:
- self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -112,10 +81,8 @@ class ReplicationStreamer(object):
self.command_handler = hs.get_tcp_replication()
- def get_streams(self) -> Dict[str, Stream]:
- """Get a mapp from stream name to stream instance.
- """
- return self.streams_by_name
+ # Set of streams to replicate.
+ self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
|