summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/streams.py4
-rw-r--r--synapse/replication/tcp/handler.py55
-rw-r--r--synapse/replication/tcp/resource.py39
3 files changed, 50 insertions, 48 deletions
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 0459f582bf..b705a8e16c 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -52,9 +52,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
 
         self._instance_name = hs.get_instance_name()
 
-        # We pull the streams from the replication steamer (if we try and make
+        # We pull the streams from the replication handler (if we try and make
         # them ourselves we end up in an import loop).
-        self.streams = hs.get_replication_streamer().get_streams()
+        self.streams = hs.get_tcp_replication().get_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token):
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e6a50aa74e..acfa66a7a8 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -36,7 +36,12 @@ from synapse.replication.tcp.commands import (
     UserSyncCommand,
 )
 from synapse.replication.tcp.protocol import AbstractConnection
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
@@ -73,6 +78,26 @@ class ReplicationCommandHandler:
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
 
+        # List of streams that this instance is the source of
+        self._streams_to_replicate = []  # type: List[Stream]
+
+        for stream in self._streams.values():
+            if stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream.
+                self._streams_to_replicate.append(stream)
+                continue
+
+            # Only add any other streams if we're on master.
+            if hs.config.worker_app is not None:
+                continue
+
+            if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+                # We only support federation stream if federation sending
+                # has been disabled on the master.
+                continue
+
+            self._streams_to_replicate.append(stream)
+
         self._position_linearizer = Linearizer(
             "replication_position", clock=self._clock
         )
@@ -150,6 +175,16 @@ class ReplicationCommandHandler:
             port = hs.config.worker_replication_port
             hs.get_reactor().connectTCP(host, port, self._factory)
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to all streams.
+        """
+        return self._streams
+
+    def get_streams_to_replicate(self) -> List[Stream]:
+        """Get a list of streams that this instances replicates.
+        """
+        return self._streams_to_replicate
+
     async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
@@ -158,15 +193,15 @@ class ReplicationCommandHandler:
         the connection.
         """
 
-        # We only want to announce positions by the writer of the streams.
-        # Currently this is just the master process.
-        if not self._is_master:
-            return
-
-        for stream_name, stream in self._streams.items():
-            current_token = stream.current_token(self._instance_name)
-            conn.send_command(
-                PositionCommand(stream_name, self._instance_name, current_token)
+        # We respond with current position of all streams this instance
+        # replicates.
+        for stream in self.get_streams_to_replicate():
+            self.send_command(
+                PositionCommand(
+                    stream.NAME,
+                    self._instance_name,
+                    stream.current_token(self._instance_name),
+                )
             )
 
     async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 002171ce7c..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,6 @@
 
 import logging
 import random
-from typing import Dict, List
 
 from prometheus_client import Counter
 
@@ -25,12 +24,6 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import (
-    STREAMS_MAP,
-    CachesStream,
-    FederationStream,
-    Stream,
-)
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -80,31 +73,7 @@ class ReplicationStreamer(object):
 
         self._replication_torture_level = hs.config.replication_torture_level
 
-        # Work out list of streams that this instance is the source of.
-        self.streams = []  # type: List[Stream]
-
-        # All workers can write to the cache invalidation stream.
-        self.streams.append(CachesStream(hs))
-
-        if hs.config.worker_app is None:
-            for stream in STREAMS_MAP.values():
-                if stream == FederationStream and hs.config.send_federation:
-                    # We only support federation stream if federation sending
-                    # has been disabled on the master.
-                    continue
-
-                if stream == CachesStream:
-                    # We've already added it above.
-                    continue
-
-                self.streams.append(stream(hs))
-
-        self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
-        # Only bother registering the notifier callback if we have streams to
-        # publish.
-        if self.streams:
-            self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_replication_callback(self.on_notifier_poke)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -112,10 +81,8 @@ class ReplicationStreamer(object):
 
         self.command_handler = hs.get_tcp_replication()
 
-    def get_streams(self) -> Dict[str, Stream]:
-        """Get a mapp from stream name to stream instance.
-        """
-        return self.streams_by_name
+        # Set of streams to replicate.
+        self.streams = self.command_handler.get_streams_to_replicate()
 
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the